M2.975 · Deep Learning · PAC1
2024-2 · Màster universitari en Ciència de dades (Data science)
Estudis de Informàtica, Multimèdia y Telecomunicació
Al llarg d'aquesta pràctica, implementarem diversos models de xarxes neuronals, utilitzant Keras i la base de dades Imagenette (versió 320px). Concretament, abordarem les tasques següents:
Consideracions generals :
Format del lliurament :
Al llarg d’aquesta pràctica, implementarem diversos models de xarxes neuronals per classificar les imatges de la base de dades Imagenette.
La base de dades Imagenette és un subconjunt de 10 classes fàcilment classificables de Imagenet, un projecte fonamental per avançar en la investigació sobre visió artificial i aprenentatge profund. Imagenette conté unes 13.000 imatges de diferents mides pertanyents a 10 categories (tench, English springer, cassette player, chain saw, church, French horn, garbage truck, gas pump, golf ball, parachute), cadascuna en una carpeta diferent.
Concretament en aquesta PAC, utilitzarem una versió (Imagenette2-320) que ha estat re-escalada, però mantenint la relació d’aspecte de cada imatge (s’han ajustat de manera que la dimensió menor de cada imatge sigui de 320 píxels). Això atenuarà la càrrega computacional dels algoritmes quan s’utilitzen bases de dades d’imatges, però mantenint una qualitat suficient necessària per als nostres experiments. Les dades venen separades en 2 conjunts, entrenament i validació.
Nota: A causa de l’ús d’imatges com dades d’aquesta pràctica, l’entrenament de cada exercici es pot retardar entre uns minuts i més de mitja hora mitjançant GPU (els temps utilitzant CPU són significativament més llargs). Es recomana realitzar la pràctica en l’entorn que ofereix la plataforma Kaggle, ja que ofereix un entorn gratuït amb 30 hores setmanals per a l’ús de la GPU.
Al llarg de tota la pràctica, per a la creació de les diferents xarxes, alternem l’ús del model Sequential i el model Functional de Keras a través de les seves classes Sequential i Model respectivament.
Es recomana la lectura detallada de la documentació dels dos models per dur a terme la realització de la pràctica.
Comencem per instal·lar i carregar les llibreries més rellevants:
# Instal·lem la darrera versió de Tensorflow (amb CUDA)
%pip install tensorflow[and-cuda]
# Importem Tensorflow
import tensorflow as tf
print("TensorFlow version : ", tf.__version__)
# Necessitarem GPU
print("GPU is", "available" if tf.config.list_physical_devices('GPU') else "NOT AVAILABLE")
# Keras versió is 3.5.0
from tensorflow import keras
print("Keras version : ", keras.__version__)
TensorFlow version : 2.17.1 GPU is available Keras version : 3.5.0
# Importem els elements de Keras que utilitzarem amb més freqüència
from keras.utils import image_dataset_from_directory
from keras.layers import (Input, Dense, Dropout, Flatten, Conv2D, Conv2DTranspose,
MaxPooling2D, UpSampling2D, Rescaling, Resizing,
RandomFlip, RandomRotation)
from keras.callbacks import EarlyStopping, ReduceLROnPlateau
from keras.optimizers import Adam
# Importem algunes llibreries que necessitarem per a la PAC
import numpy as np
import matplotlib.pyplot as plt
from datetime import datetime
En aquesta secció explorarem la base de dades i prepararem la càrrega de les imatges per als models de les seccions següents.
Per crear la nostra base de dades, hem de descarregar el fitxer d’imatges del següent enllaç (és un fitxer .zip que ocupa aproximadament 340 Mb).
Nota: per descarregar el fitxer d'imatges, heu de iniciar la sessió amb l'usuari i la contrasenya de la UOC.
A partir d’aquí:
Si treballem en local, simplement hem de descomprimir el fitxer descarregat.
Si treballem des de Kaggle hem de pujar el notebook de l'enunciat a la plataforma (per això podeu seguir els 6 primers passos del següent article) i després, un cop pujat el notebook, expandir la barra lateral desplegable de la dreta i al menú 'Input' clickar el botó 'Upload' i pujar el fitxer descarregat prèviament. Després cal donar un nom a la base de dades i quan es carregui el fitxer ja tindreu accessible la base de dades a la ruta ../ input/.
Un cop tinguem la base de dades accessible, la inspeccionarem.
A la carpeta /images (si treballem a casa) o /kaggle/input/nom-base-de-dades/images (si treballem des de Kaggle) trobem 2 carpetes:
/train es troba el total de les imatges d'entrenament separades per classes (cada classe en una carpeta diferent)./val, es troba el total de les imatges de validació separades per classes (cada classe en una carpeta diferent).Com podem veure, tenim imatges per dur a terme l'entrenament i la validació dels models, però no tenim un conjunt de prova, el crearem durant aquesta primera secció.
Comencem per obtenir les dades i analitzar la seva estructura i característiques.
Primer, inspeccionarem l’organització de les dades.
/train i /val):
import os
from pathlib import Path
pac_dir = Path("/kaggle/input/uoc-dl-pac1-v2/images")
train_dir = pac_dir / "train"
val_dir = pac_dir / "val"
# Llistar classes a cada conjunt
train_classes = sorted([d.name for d in train_dir.iterdir() if d.is_dir()])
val_classes = sorted([d.name for d in val_dir.iterdir() if d.is_dir()])
print("Classes a train:", train_classes)
print("Classes a validation:", val_classes)
Classes a train: ['English_springer', 'French_horn', 'cassette_player', 'chain_saw', 'church', 'garbage_truck', 'gas_pump', 'golf_ball', 'parachute', 'tench'] Classes a validation: ['English_springer', 'French_horn', 'cassette_player', 'chain_saw', 'church', 'garbage_truck', 'gas_pump', 'golf_ball', 'parachute', 'tench']
# Verificar que les dues llistes de classes són iguals
if train_classes == val_classes:
print("Les classes de train i validation són iguals")
else:
print("Les classes de train i validation no són iguals")
Les classes de train i validation són iguals
# Comptar total d'imatges per conjunt i per classe
from pathlib import Path
def count_images(directory, classes):
"""
Compta el nombre d'imatges amb extensió .JPEG en diferents subdirectoris.
Parameters:
directory (str o Path): El directori principal que conté els subdirectoris.
classes (list): Una llista amb els noms dels subdirectoris que es vol analitzar.
Returns:
dict: Un diccionari on les claus són els noms de les classes
i els valors són el nombre d'imatges .JPEG trobades en cada classe.
"""
counts = {}
for class_name in classes:
class_dir = Path(directory) / class_name
images = list(class_dir.glob("*.JPEG"))
counts[class_name] = len(images)
return counts
# Imprimeix nombre d'imatges per classe a train
train_counts = count_images(train_dir, train_classes)
print("Nombre d'imatges per classe a train:", train_counts)
# Imprimeix nombre d'imatges per classe a val
val_counts = count_images(val_dir, train_classes)
print("\n Nombre d'imatges per classe a validation:", val_counts)
Nombre d'imatges per classe a train: {'English_springer': 955, 'French_horn': 956, 'cassette_player': 993, 'chain_saw': 858, 'church': 941, 'garbage_truck': 961, 'gas_pump': 931, 'golf_ball': 951, 'parachute': 960, 'tench': 963}
Nombre d'imatges per classe a validation: {'English_springer': 395, 'French_horn': 394, 'cassette_player': 357, 'chain_saw': 386, 'church': 409, 'garbage_truck': 389, 'gas_pump': 419, 'golf_ball': 399, 'parachute': 390, 'tench': 387}
# Representació gràfica de la distribució de classes
plt.figure(figsize=(10, 5))
plt.bar(train_counts.keys(), train_counts.values())
plt.title("Distribució d'imatges per clase - Train")
plt.xlabel("Classes")
plt.ylabel("Nombre d'imatges")
plt.xticks(rotation=45)
plt.tight_layout()
plt.show()
plt.figure(figsize=(10, 5))
plt.bar(val_counts.keys(), val_counts.values())
plt.title("Distribució d'imatges per clase - Validation")
plt.xlabel("Classes")
plt.ylabel("Nombre d'imatges")
plt.xticks(rotation=45)
plt.tight_layout()
plt.show()
# Calcul del percentatge total
total_train = sum(train_counts.values())
total_val = sum(val_counts.values())
total_images = total_train + total_val
train_percentage = (total_train / total_images) * 100
val_percentage = (total_val / total_images) * 100
print(f"Percentatge d'imatges al conjunt Train: {train_percentage:.2f}%")
print(f"Percentatge d'imatges al conjunt Validation: {val_percentage:.2f}%")
Percentatge d'imatges al conjunt Train: 70.70% Percentatge d'imatges al conjunt Validation: 29.30%
Ara examinarem el format de les imatges per entendre la seva mida i el seu rang de valors. Visualitzarem algunes imatges d’exemple de cada classe.
import random
from PIL import Image
plt.figure(figsize=(15, 10))
for i, classe in enumerate(train_classes):
classe_dir = train_dir / classe
imatges = list(classe_dir.glob("*.JPEG"))
if imatges:
imatge_path = random.choice(imatges)
imatge = Image.open(imatge_path)
imatge_array = np.array(imatge)
ample, alt = imatge.size
valor_min = imatge_array.min()
valor_max = imatge_array.max()
plt.subplot(2, 5, i+1)
plt.imshow(imatge)
plt.title(f"{classe}\n{alt}x{ample}\n{valor_min}-{valor_max}")
plt.axis("off")
print(f"Classe: {classe}, Dimensions: {alt}x{ample}, Rang dinàmic: {valor_min} - {valor_max}")
plt.tight_layout()
plt.show()
Classe: English_springer, Dimensions: 320x480, Rang dinàmic: 0 - 255 Classe: French_horn, Dimensions: 426x320, Rang dinàmic: 0 - 255 Classe: cassette_player, Dimensions: 320x388, Rang dinàmic: 4 - 255 Classe: chain_saw, Dimensions: 426x320, Rang dinàmic: 0 - 255 Classe: church, Dimensions: 321x320, Rang dinàmic: 0 - 255 Classe: garbage_truck, Dimensions: 320x426, Rang dinàmic: 0 - 255 Classe: gas_pump, Dimensions: 320x426, Rang dinàmic: 0 - 255 Classe: golf_ball, Dimensions: 393x320, Rang dinàmic: 0 - 255 Classe: parachute, Dimensions: 320x451, Rang dinàmic: 0 - 255 Classe: tench, Dimensions: 320x426, Rang dinàmic: 0 - 255
A continuació, prepararem les dades per a l'entrenament amb Keras. Utilitzarem la funció **tf.keras.utils.image_dataset_from_directory()** de TensorFlow/Keras, que permet crear lots de dades etiquetats en funció dels directoris d'imatges organitzats per classe.
La documentació d'aquesta funció es troba tant al lloc web de Keras com al de TensorFlow.
Aquesta funció ens facilitarà generar conjunts d'entrenament, validació i prova a partir de les carpetes analitzades. Les imatges es re-dimensionaran a una mida fixa i s’organitzaran en lots (batch).
Especificacions: convertirem les imatges a la mida de 160 × 160 píxels i les agruparem en lots de 64 imatges. També separarem part de les dades de prova a partir de les de validació.
tf.keras.utils.image_dataset_from_directory () per generar 3 conjunts de dades a partir de les carpetes /train i /val:
/train/. Redimensiona les imatges a 160×160, amb batch_size=64 i label_mode="categorical" (10 categories)./val/, amb un validation_split de 0.5 usant subset="validation" a la funció i fixa un seed per a reproducibilitat. De nou redimensiona les imatges a 160x160 amb batch_size=64 i label_mode="categorical"./val/ usant subset="training" a la funció (amb el mateix seed per obtenir la partició complementària). Redimensiona també a 160×160 amb batch_size=64 i fes servir label_mode="categorical".subset="validation" i subset="training" és completament arbitrària i podria haver-se fet al revés. L'important és dividir les dades que es troben a la carpeta /val al 50% entre validació i test.
img_height, img_width = 160, 160
batch_size = 64
seed = 42
# Conjunt d'entrenament
train_dataset = tf.keras.utils.image_dataset_from_directory(
train_dir,
image_size=(img_height, img_width),
batch_size=batch_size,
label_mode="categorical"
)
Found 9469 files belonging to 10 classes.
# Conjunt de validació i test (usem el directori de validació d'Imagenette també per a test)
val_dataset = tf.keras.utils.image_dataset_from_directory(
val_dir,
validation_split=0.5,
subset="validation",
seed=seed,
image_size=(img_height, img_width),
batch_size=batch_size,
label_mode="categorical"
)
test_dataset = tf.keras.utils.image_dataset_from_directory(
val_dir,
validation_split=0.5,
subset="training",
seed=seed,
image_size=(img_height, img_width),
batch_size=batch_size,
label_mode="categorical"
)
Found 3925 files belonging to 10 classes. Using 1962 files for validation. Found 3925 files belonging to 10 classes. Using 1963 files for training.
# Comprovació dels resultats
def imprimeix_info(dataset, nom):
for imatges, etiquetes in dataset.take(1):
print(f"{nom}: Imatges {imatges.shape}, Etiquetes {etiquetes.shape}")
imprimeix_info(train_dataset, "Train")
imprimeix_info(val_dataset, "Validation")
imprimeix_info(test_dataset, "Test")
Train: Imatges (64, 160, 160, 3), Etiquetes (64, 10) Validation: Imatges (64, 160, 160, 3), Etiquetes (64, 10) Test: Imatges (64, 160, 160, 3), Etiquetes (64, 10)
Com a primer model, entrenarem una xarxa neuronal completament connectada (xarxa densa o
Arquitectura proposada: Utilitzarem l’API funcional de Keras (classe Model) per construir la xarxa. Utilitzarem les capes Resizing i Rescaling de Keras per preparar les imatges, seguides de Flatten per aplanar-les, i diverses capes Dense intercalades amb Dropout pel classificador.
En aquesta secció utilitzarem les capes Resizing, Rescaling, Flatten, Dense i Dropout de Keras.
summary()
# Construcció del model ANN completament connectat (model funcional)
from keras import layers, Model
def ANN_model():
# Cada d'entrada
input_shape = (160, 160, 3)
inputs = Input(shape=input_shape)
# 160x160 -> 32x32
x = layers.Resizing(32, 32)(inputs)
# Reescalat valors píxel entre 0 i 1
x = layers.Rescaling(1./255)(x)
# Flatten (32x32x3 = 3072 elements)
x = layers.Flatten()(x)
# Capa densa de 3072*2/3 = 2048 elements, activació ReLu i Dropout
x = layers.Dense(2048, activation='relu')(x)
x = layers.Dropout(0.5)(x)
# Capa densa de 2048*1/2 = 1024 elements, activació ReLu i Dropout
x = layers.Dense(1024, activation='relu')(x)
x = layers.Dropout(0.5)(x)
# Capa de sortida amb 10 neurones (una per classe) i softmax (classificació multiclasse)
outputs = layers.Dense(10, activation='softmax')(x)
model = Model(inputs=inputs, outputs=outputs)
return model
model = ANN_model()
model.summary()
Model: "functional"
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━┓ ┃ Layer (type) ┃ Output Shape ┃ Param # ┃ ┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━┩ │ input_layer (InputLayer) │ (None, 160, 160, 3) │ 0 │ ├──────────────────────────────────────┼─────────────────────────────┼─────────────────┤ │ resizing (Resizing) │ (None, 32, 32, 3) │ 0 │ ├──────────────────────────────────────┼─────────────────────────────┼─────────────────┤ │ rescaling (Rescaling) │ (None, 32, 32, 3) │ 0 │ ├──────────────────────────────────────┼─────────────────────────────┼─────────────────┤ │ flatten (Flatten) │ (None, 3072) │ 0 │ ├──────────────────────────────────────┼─────────────────────────────┼─────────────────┤ │ dense (Dense) │ (None, 2048) │ 6,293,504 │ ├──────────────────────────────────────┼─────────────────────────────┼─────────────────┤ │ dropout (Dropout) │ (None, 2048) │ 0 │ ├──────────────────────────────────────┼─────────────────────────────┼─────────────────┤ │ dense_1 (Dense) │ (None, 1024) │ 2,098,176 │ ├──────────────────────────────────────┼─────────────────────────────┼─────────────────┤ │ dropout_1 (Dropout) │ (None, 1024) │ 0 │ ├──────────────────────────────────────┼─────────────────────────────┼─────────────────┤ │ dense_2 (Dense) │ (None, 10) │ 10,250 │ └──────────────────────────────────────┴─────────────────────────────┴─────────────────┘
Total params: 8,401,930 (32.05 MB)
Trainable params: 8,401,930 (32.05 MB)
Non-trainable params: 0 (0.00 B)
Procedim a compilar i entrenar el model:
Comencem amb lr = 1e-3
# Compilació del model
lr = 1e-3
# Guardem les dades de l'entrenament i eval per fer el plot en un diccionari
hist = {}
eval_results = {}
model.compile(
optimizer=tf.keras.optimizers.Adam(learning_rate=lr),
loss='categorical_crossentropy',
metrics=['accuracy']
)
# Definir callbacks: EarlyStopping
callbacks = [
EarlyStopping(monitor='val_loss', patience=10, restore_best_weights=True),
ReduceLROnPlateau(monitor='val_loss', factor=0.2, patience=5, min_lr=1e-6)
]
# Entrenament del model
hist = model.fit(train_dataset,
validation_data=val_dataset,
epochs=100,
callbacks=callbacks)
Epoch 1/100 148/148 ━━━━━━━━━━━━━━━━━━━━ 19s 107ms/step - accuracy: 0.1512 - loss: 3.1145 - val_accuracy: 0.2971 - val_loss: 2.0193 - learning_rate: 0.0010 Epoch 2/100 148/148 ━━━━━━━━━━━━━━━━━━━━ 7s 45ms/step - accuracy: 0.2405 - loss: 2.1078 - val_accuracy: 0.2951 - val_loss: 2.0005 - learning_rate: 0.0010 Epoch 3/100 148/148 ━━━━━━━━━━━━━━━━━━━━ 6s 42ms/step - accuracy: 0.2519 - loss: 2.0991 - val_accuracy: 0.3140 - val_loss: 1.9567 - learning_rate: 0.0010 Epoch 4/100 148/148 ━━━━━━━━━━━━━━━━━━━━ 6s 42ms/step - accuracy: 0.2563 - loss: 2.0524 - val_accuracy: 0.3359 - val_loss: 1.9592 - learning_rate: 0.0010 Epoch 5/100 148/148 ━━━━━━━━━━━━━━━━━━━━ 6s 42ms/step - accuracy: 0.2602 - loss: 2.0343 - val_accuracy: 0.3542 - val_loss: 1.9346 - learning_rate: 0.0010 Epoch 6/100 148/148 ━━━━━━━━━━━━━━━━━━━━ 7s 45ms/step - accuracy: 0.2801 - loss: 2.0207 - val_accuracy: 0.3792 - val_loss: 1.9003 - learning_rate: 0.0010 Epoch 7/100 148/148 ━━━━━━━━━━━━━━━━━━━━ 6s 43ms/step - accuracy: 0.2972 - loss: 1.9747 - val_accuracy: 0.3619 - val_loss: 1.9180 - learning_rate: 0.0010 Epoch 8/100 148/148 ━━━━━━━━━━━━━━━━━━━━ 6s 43ms/step - accuracy: 0.2900 - loss: 1.9908 - val_accuracy: 0.3481 - val_loss: 1.9231 - learning_rate: 0.0010 Epoch 9/100 148/148 ━━━━━━━━━━━━━━━━━━━━ 7s 44ms/step - accuracy: 0.2959 - loss: 1.9757 - val_accuracy: 0.3828 - val_loss: 1.8785 - learning_rate: 0.0010 Epoch 10/100 148/148 ━━━━━━━━━━━━━━━━━━━━ 6s 43ms/step - accuracy: 0.3061 - loss: 1.9579 - val_accuracy: 0.3797 - val_loss: 1.9005 - learning_rate: 0.0010 Epoch 11/100 148/148 ━━━━━━━━━━━━━━━━━━━━ 7s 45ms/step - accuracy: 0.3212 - loss: 1.9391 - val_accuracy: 0.3542 - val_loss: 1.9211 - learning_rate: 0.0010 Epoch 12/100 148/148 ━━━━━━━━━━━━━━━━━━━━ 6s 42ms/step - accuracy: 0.3135 - loss: 1.9382 - val_accuracy: 0.3899 - val_loss: 1.8628 - learning_rate: 0.0010 Epoch 13/100 148/148 ━━━━━━━━━━━━━━━━━━━━ 6s 42ms/step - accuracy: 0.3378 - loss: 1.8907 - val_accuracy: 0.3782 - val_loss: 1.8826 - learning_rate: 0.0010 Epoch 14/100 148/148 ━━━━━━━━━━━━━━━━━━━━ 6s 42ms/step - accuracy: 0.3439 - loss: 1.8889 - val_accuracy: 0.3364 - val_loss: 1.9070 - learning_rate: 0.0010 Epoch 15/100 148/148 ━━━━━━━━━━━━━━━━━━━━ 6s 42ms/step - accuracy: 0.3328 - loss: 1.9042 - val_accuracy: 0.3721 - val_loss: 1.8842 - learning_rate: 0.0010 Epoch 16/100 148/148 ━━━━━━━━━━━━━━━━━━━━ 7s 44ms/step - accuracy: 0.3345 - loss: 1.9094 - val_accuracy: 0.3960 - val_loss: 1.8664 - learning_rate: 0.0010 Epoch 17/100 148/148 ━━━━━━━━━━━━━━━━━━━━ 7s 44ms/step - accuracy: 0.3348 - loss: 1.8767 - val_accuracy: 0.3940 - val_loss: 1.8529 - learning_rate: 0.0010 Epoch 18/100 148/148 ━━━━━━━━━━━━━━━━━━━━ 7s 44ms/step - accuracy: 0.3384 - loss: 1.8673 - val_accuracy: 0.3660 - val_loss: 1.8862 - learning_rate: 0.0010 Epoch 19/100 148/148 ━━━━━━━━━━━━━━━━━━━━ 6s 44ms/step - accuracy: 0.3554 - loss: 1.8419 - val_accuracy: 0.3879 - val_loss: 1.8871 - learning_rate: 0.0010 Epoch 20/100 148/148 ━━━━━━━━━━━━━━━━━━━━ 6s 43ms/step - accuracy: 0.3473 - loss: 1.8462 - val_accuracy: 0.3802 - val_loss: 1.8547 - learning_rate: 0.0010 Epoch 21/100 148/148 ━━━━━━━━━━━━━━━━━━━━ 7s 45ms/step - accuracy: 0.3559 - loss: 1.8503 - val_accuracy: 0.3812 - val_loss: 1.8773 - learning_rate: 0.0010 Epoch 22/100 148/148 ━━━━━━━━━━━━━━━━━━━━ 6s 42ms/step - accuracy: 0.3498 - loss: 1.8547 - val_accuracy: 0.3818 - val_loss: 1.8883 - learning_rate: 0.0010 Epoch 23/100 148/148 ━━━━━━━━━━━━━━━━━━━━ 6s 43ms/step - accuracy: 0.3803 - loss: 1.7884 - val_accuracy: 0.4062 - val_loss: 1.8068 - learning_rate: 2.0000e-04 Epoch 24/100 148/148 ━━━━━━━━━━━━━━━━━━━━ 6s 42ms/step - accuracy: 0.3981 - loss: 1.7487 - val_accuracy: 0.3986 - val_loss: 1.7960 - learning_rate: 2.0000e-04 Epoch 25/100 148/148 ━━━━━━━━━━━━━━━━━━━━ 6s 42ms/step - accuracy: 0.4092 - loss: 1.7265 - val_accuracy: 0.4134 - val_loss: 1.7944 - learning_rate: 2.0000e-04 Epoch 26/100 148/148 ━━━━━━━━━━━━━━━━━━━━ 7s 45ms/step - accuracy: 0.4061 - loss: 1.7136 - val_accuracy: 0.4113 - val_loss: 1.7863 - learning_rate: 2.0000e-04 Epoch 27/100 148/148 ━━━━━━━━━━━━━━━━━━━━ 6s 43ms/step - accuracy: 0.4111 - loss: 1.7017 - val_accuracy: 0.4123 - val_loss: 1.7826 - learning_rate: 2.0000e-04 Epoch 28/100 148/148 ━━━━━━━━━━━━━━━━━━━━ 6s 43ms/step - accuracy: 0.4186 - loss: 1.6887 - val_accuracy: 0.4149 - val_loss: 1.7818 - learning_rate: 2.0000e-04 Epoch 29/100 148/148 ━━━━━━━━━━━━━━━━━━━━ 6s 43ms/step - accuracy: 0.4113 - loss: 1.6845 - val_accuracy: 0.4088 - val_loss: 1.7711 - learning_rate: 2.0000e-04 Epoch 30/100 148/148 ━━━━━━━━━━━━━━━━━━━━ 6s 43ms/step - accuracy: 0.4160 - loss: 1.7014 - val_accuracy: 0.4149 - val_loss: 1.7756 - learning_rate: 2.0000e-04 Epoch 31/100 148/148 ━━━━━━━━━━━━━━━━━━━━ 7s 44ms/step - accuracy: 0.4146 - loss: 1.6822 - val_accuracy: 0.4006 - val_loss: 1.7788 - learning_rate: 2.0000e-04 Epoch 32/100 148/148 ━━━━━━━━━━━━━━━━━━━━ 6s 43ms/step - accuracy: 0.4159 - loss: 1.6855 - val_accuracy: 0.4128 - val_loss: 1.7786 - learning_rate: 2.0000e-04 Epoch 33/100 148/148 ━━━━━━━━━━━━━━━━━━━━ 6s 43ms/step - accuracy: 0.4175 - loss: 1.6632 - val_accuracy: 0.4139 - val_loss: 1.7718 - learning_rate: 2.0000e-04 Epoch 34/100 148/148 ━━━━━━━━━━━━━━━━━━━━ 7s 44ms/step - accuracy: 0.4366 - loss: 1.6397 - val_accuracy: 0.4179 - val_loss: 1.7674 - learning_rate: 2.0000e-04 Epoch 35/100 148/148 ━━━━━━━━━━━━━━━━━━━━ 7s 44ms/step - accuracy: 0.4227 - loss: 1.6611 - val_accuracy: 0.4159 - val_loss: 1.7596 - learning_rate: 2.0000e-04 Epoch 36/100 148/148 ━━━━━━━━━━━━━━━━━━━━ 7s 45ms/step - accuracy: 0.4328 - loss: 1.6531 - val_accuracy: 0.4113 - val_loss: 1.7780 - learning_rate: 2.0000e-04 Epoch 37/100 148/148 ━━━━━━━━━━━━━━━━━━━━ 6s 43ms/step - accuracy: 0.4309 - loss: 1.6401 - val_accuracy: 0.4123 - val_loss: 1.7831 - learning_rate: 2.0000e-04 Epoch 38/100 148/148 ━━━━━━━━━━━━━━━━━━━━ 6s 43ms/step - accuracy: 0.4337 - loss: 1.6404 - val_accuracy: 0.4128 - val_loss: 1.7684 - learning_rate: 2.0000e-04 Epoch 39/100 148/148 ━━━━━━━━━━━━━━━━━━━━ 6s 43ms/step - accuracy: 0.4312 - loss: 1.6471 - val_accuracy: 0.4154 - val_loss: 1.7762 - learning_rate: 2.0000e-04 Epoch 40/100 148/148 ━━━━━━━━━━━━━━━━━━━━ 7s 45ms/step - accuracy: 0.4377 - loss: 1.6358 - val_accuracy: 0.4271 - val_loss: 1.7641 - learning_rate: 2.0000e-04 Epoch 41/100 148/148 ━━━━━━━━━━━━━━━━━━━━ 7s 45ms/step - accuracy: 0.4440 - loss: 1.5960 - val_accuracy: 0.4200 - val_loss: 1.7743 - learning_rate: 4.0000e-05 Epoch 42/100 148/148 ━━━━━━━━━━━━━━━━━━━━ 6s 42ms/step - accuracy: 0.4535 - loss: 1.6008 - val_accuracy: 0.4169 - val_loss: 1.7741 - learning_rate: 4.0000e-05 Epoch 43/100 148/148 ━━━━━━━━━━━━━━━━━━━━ 6s 42ms/step - accuracy: 0.4515 - loss: 1.5936 - val_accuracy: 0.4210 - val_loss: 1.7727 - learning_rate: 4.0000e-05 Epoch 44/100 148/148 ━━━━━━━━━━━━━━━━━━━━ 6s 43ms/step - accuracy: 0.4506 - loss: 1.5895 - val_accuracy: 0.4200 - val_loss: 1.7711 - learning_rate: 4.0000e-05 Epoch 45/100 148/148 ━━━━━━━━━━━━━━━━━━━━ 6s 42ms/step - accuracy: 0.4566 - loss: 1.5808 - val_accuracy: 0.4134 - val_loss: 1.7735 - learning_rate: 4.0000e-05
# Funció auxiliar per graficar accuracy i loss
def grafica_accuracy_loss(hist, lr):
"""
Funció auxiliar per graficar les corbes de loss i accuracy de l'entrenament i validació.
"""
epochs = range(1, len(hist.history['loss']) + 1)
plt.figure(figsize=(12, 5))
# Gràfic de la pèrdua (loss)
plt.subplot(1, 2, 1)
plt.plot(epochs, hist.history['loss'], label='Pèrdua Entrenament')
plt.plot(epochs, hist.history['val_loss'], label='Pèrdua Validació')
plt.title(f"Pèrdua (Learning rate = {lr})")
plt.xlabel("Època")
plt.ylabel("Pèrdua")
plt.legend()
# Gràfic de l'accuracy
plt.subplot(1, 2, 2)
plt.plot(epochs, hist.history['accuracy'], label='Accuracy Entrenament')
plt.plot(epochs, hist.history['val_accuracy'], label='Accuracy Validació')
plt.title(f"Accuracy (Learning rate = {lr})")
plt.xlabel("Època")
plt.ylabel("Accuracy")
plt.legend()
plt.tight_layout()
plt.show()
grafica_accuracy_loss(hist, lr)
# Validació al conjunt de prova
test_loss, test_accuracy = model.evaluate(test_dataset)
print(f"Pèrdua al conjunt de test: {test_loss:.4f}")
print(f"Accuracy al conjunt de test: {test_accuracy:.4f}")
31/31 ━━━━━━━━━━━━━━━━━━━━ 3s 83ms/step - accuracy: 0.3864 - loss: 1.8106 Pèrdua al conjunt de test: 1.8185 Accuracy al conjunt de test: 0.3923
Repetim amb lr = 1e-4
lr = 1e-4
# Crear un nou model sense pesos previs
model = ANN_model()
# Compilació del model
model.compile(
optimizer=tf.keras.optimizers.Adam(learning_rate=lr),
loss='categorical_crossentropy',
metrics=['accuracy']
)
# Definir callbacks: EarlyStopping
# Fet anteriorment
# Entrenament del model
hist = model.fit(
train_dataset,
validation_data=val_dataset,
epochs=100,
callbacks=callbacks
)
grafica_accuracy_loss(hist, lr)
# Validació al conjunt de prova
test_loss, test_accuracy = model.evaluate(test_dataset)
print(f"Pèrdua al conjunt de prova: {test_loss:.4f}")
print(f"Accuracy al conjunt de prova: {test_accuracy:.4f}")
Epoch 1/100 148/148 ━━━━━━━━━━━━━━━━━━━━ 10s 55ms/step - accuracy: 0.1723 - loss: 2.3385 - val_accuracy: 0.2895 - val_loss: 2.0008 - learning_rate: 1.0000e-04 Epoch 2/100 148/148 ━━━━━━━━━━━━━━━━━━━━ 6s 43ms/step - accuracy: 0.2747 - loss: 2.0615 - val_accuracy: 0.3496 - val_loss: 1.9361 - learning_rate: 1.0000e-04 Epoch 3/100 148/148 ━━━━━━━━━━━━━━━━━━━━ 6s 42ms/step - accuracy: 0.3096 - loss: 1.9859 - val_accuracy: 0.3624 - val_loss: 1.8784 - learning_rate: 1.0000e-04 Epoch 4/100 148/148 ━━━━━━━━━━━━━━━━━━━━ 6s 41ms/step - accuracy: 0.3363 - loss: 1.9284 - val_accuracy: 0.3812 - val_loss: 1.8274 - learning_rate: 1.0000e-04 Epoch 5/100 148/148 ━━━━━━━━━━━━━━━━━━━━ 7s 44ms/step - accuracy: 0.3449 - loss: 1.8995 - val_accuracy: 0.3981 - val_loss: 1.7952 - learning_rate: 1.0000e-04 Epoch 6/100 148/148 ━━━━━━━━━━━━━━━━━━━━ 6s 42ms/step - accuracy: 0.3636 - loss: 1.8487 - val_accuracy: 0.4113 - val_loss: 1.7742 - learning_rate: 1.0000e-04 Epoch 7/100 148/148 ━━━━━━━━━━━━━━━━━━━━ 6s 41ms/step - accuracy: 0.3792 - loss: 1.8188 - val_accuracy: 0.4037 - val_loss: 1.7671 - learning_rate: 1.0000e-04 Epoch 8/100 148/148 ━━━━━━━━━━━━━━━━━━━━ 6s 43ms/step - accuracy: 0.3864 - loss: 1.7836 - val_accuracy: 0.4179 - val_loss: 1.7521 - learning_rate: 1.0000e-04 Epoch 9/100 148/148 ━━━━━━━━━━━━━━━━━━━━ 6s 43ms/step - accuracy: 0.3978 - loss: 1.7556 - val_accuracy: 0.4225 - val_loss: 1.7349 - learning_rate: 1.0000e-04 Epoch 10/100 148/148 ━━━━━━━━━━━━━━━━━━━━ 7s 45ms/step - accuracy: 0.4133 - loss: 1.7282 - val_accuracy: 0.4179 - val_loss: 1.7151 - learning_rate: 1.0000e-04 Epoch 11/100 148/148 ━━━━━━━━━━━━━━━━━━━━ 6s 43ms/step - accuracy: 0.4144 - loss: 1.7089 - val_accuracy: 0.4409 - val_loss: 1.6924 - learning_rate: 1.0000e-04 Epoch 12/100 148/148 ━━━━━━━━━━━━━━━━━━━━ 6s 42ms/step - accuracy: 0.4320 - loss: 1.6852 - val_accuracy: 0.4393 - val_loss: 1.6810 - learning_rate: 1.0000e-04 Epoch 13/100 148/148 ━━━━━━━━━━━━━━━━━━━━ 6s 42ms/step - accuracy: 0.4445 - loss: 1.6540 - val_accuracy: 0.4424 - val_loss: 1.6686 - learning_rate: 1.0000e-04 Epoch 14/100 148/148 ━━━━━━━━━━━━━━━━━━━━ 6s 42ms/step - accuracy: 0.4547 - loss: 1.6074 - val_accuracy: 0.4470 - val_loss: 1.6673 - learning_rate: 1.0000e-04 Epoch 15/100 148/148 ━━━━━━━━━━━━━━━━━━━━ 7s 45ms/step - accuracy: 0.4566 - loss: 1.6093 - val_accuracy: 0.4388 - val_loss: 1.6705 - learning_rate: 1.0000e-04 Epoch 16/100 148/148 ━━━━━━━━━━━━━━━━━━━━ 6s 43ms/step - accuracy: 0.4662 - loss: 1.5775 - val_accuracy: 0.4393 - val_loss: 1.6650 - learning_rate: 1.0000e-04 Epoch 17/100 148/148 ━━━━━━━━━━━━━━━━━━━━ 6s 43ms/step - accuracy: 0.4692 - loss: 1.5637 - val_accuracy: 0.4562 - val_loss: 1.6496 - learning_rate: 1.0000e-04 Epoch 18/100 148/148 ━━━━━━━━━━━━━━━━━━━━ 6s 43ms/step - accuracy: 0.4885 - loss: 1.5399 - val_accuracy: 0.4618 - val_loss: 1.6296 - learning_rate: 1.0000e-04 Epoch 19/100 148/148 ━━━━━━━━━━━━━━━━━━━━ 7s 47ms/step - accuracy: 0.4959 - loss: 1.5036 - val_accuracy: 0.4577 - val_loss: 1.6304 - learning_rate: 1.0000e-04 Epoch 20/100 148/148 ━━━━━━━━━━━━━━━━━━━━ 7s 48ms/step - accuracy: 0.5115 - loss: 1.4722 - val_accuracy: 0.4526 - val_loss: 1.6520 - learning_rate: 1.0000e-04 Epoch 21/100 148/148 ━━━━━━━━━━━━━━━━━━━━ 6s 44ms/step - accuracy: 0.5041 - loss: 1.4722 - val_accuracy: 0.4638 - val_loss: 1.6274 - learning_rate: 1.0000e-04 Epoch 22/100 148/148 ━━━━━━━━━━━━━━━━━━━━ 6s 42ms/step - accuracy: 0.5151 - loss: 1.4457 - val_accuracy: 0.4546 - val_loss: 1.6481 - learning_rate: 1.0000e-04 Epoch 23/100 148/148 ━━━━━━━━━━━━━━━━━━━━ 6s 42ms/step - accuracy: 0.5296 - loss: 1.4144 - val_accuracy: 0.4613 - val_loss: 1.6213 - learning_rate: 1.0000e-04 Epoch 24/100 148/148 ━━━━━━━━━━━━━━━━━━━━ 6s 42ms/step - accuracy: 0.5340 - loss: 1.3880 - val_accuracy: 0.4730 - val_loss: 1.6289 - learning_rate: 1.0000e-04 Epoch 25/100 148/148 ━━━━━━━━━━━━━━━━━━━━ 7s 44ms/step - accuracy: 0.5380 - loss: 1.3923 - val_accuracy: 0.4633 - val_loss: 1.6322 - learning_rate: 1.0000e-04 Epoch 26/100 148/148 ━━━━━━━━━━━━━━━━━━━━ 7s 44ms/step - accuracy: 0.5515 - loss: 1.3516 - val_accuracy: 0.4475 - val_loss: 1.6568 - learning_rate: 1.0000e-04 Epoch 27/100 148/148 ━━━━━━━━━━━━━━━━━━━━ 6s 43ms/step - accuracy: 0.5541 - loss: 1.3311 - val_accuracy: 0.4602 - val_loss: 1.6206 - learning_rate: 1.0000e-04 Epoch 28/100 148/148 ━━━━━━━━━━━━━━━━━━━━ 6s 43ms/step - accuracy: 0.5682 - loss: 1.2938 - val_accuracy: 0.4709 - val_loss: 1.6179 - learning_rate: 1.0000e-04 Epoch 29/100 148/148 ━━━━━━━━━━━━━━━━━━━━ 6s 43ms/step - accuracy: 0.5646 - loss: 1.2924 - val_accuracy: 0.4725 - val_loss: 1.6460 - learning_rate: 1.0000e-04 Epoch 30/100 148/148 ━━━━━━━━━━━━━━━━━━━━ 7s 45ms/step - accuracy: 0.5807 - loss: 1.2565 - val_accuracy: 0.4669 - val_loss: 1.6699 - learning_rate: 1.0000e-04 Epoch 31/100 148/148 ━━━━━━━━━━━━━━━━━━━━ 7s 44ms/step - accuracy: 0.5846 - loss: 1.2489 - val_accuracy: 0.4791 - val_loss: 1.6334 - learning_rate: 1.0000e-04 Epoch 32/100 148/148 ━━━━━━━━━━━━━━━━━━━━ 7s 45ms/step - accuracy: 0.5901 - loss: 1.2216 - val_accuracy: 0.4704 - val_loss: 1.6593 - learning_rate: 1.0000e-04 Epoch 33/100 148/148 ━━━━━━━━━━━━━━━━━━━━ 6s 42ms/step - accuracy: 0.6031 - loss: 1.1776 - val_accuracy: 0.4801 - val_loss: 1.6378 - learning_rate: 1.0000e-04 Epoch 34/100 148/148 ━━━━━━━━━━━━━━━━━━━━ 6s 43ms/step - accuracy: 0.6336 - loss: 1.1137 - val_accuracy: 0.4862 - val_loss: 1.6321 - learning_rate: 2.0000e-05 Epoch 35/100 148/148 ━━━━━━━━━━━━━━━━━━━━ 7s 45ms/step - accuracy: 0.6592 - loss: 1.0442 - val_accuracy: 0.4837 - val_loss: 1.6270 - learning_rate: 2.0000e-05 Epoch 36/100 148/148 ━━━━━━━━━━━━━━━━━━━━ 7s 45ms/step - accuracy: 0.6667 - loss: 1.0244 - val_accuracy: 0.4852 - val_loss: 1.6374 - learning_rate: 2.0000e-05 Epoch 37/100 148/148 ━━━━━━━━━━━━━━━━━━━━ 7s 44ms/step - accuracy: 0.6726 - loss: 1.0050 - val_accuracy: 0.4883 - val_loss: 1.6336 - learning_rate: 2.0000e-05 Epoch 38/100 148/148 ━━━━━━━━━━━━━━━━━━━━ 7s 45ms/step - accuracy: 0.6714 - loss: 1.0072 - val_accuracy: 0.4888 - val_loss: 1.6459 - learning_rate: 2.0000e-05
31/31 ━━━━━━━━━━━━━━━━━━━━ 1s 43ms/step - accuracy: 0.4431 - loss: 1.6269 Pèrdua al conjunt de prova: 1.6378 Accuracy al conjunt de prova: 0.4412
Acabem amb lr = 1e-5
lr = 1e-5
# Crear un nou model sense pesos previs
model = ANN_model()
# Compilació del model
model.compile(
optimizer=tf.keras.optimizers.Adam(learning_rate=lr),
loss='categorical_crossentropy',
metrics=['accuracy']
)
# Definir callbacks: EarlyStopping
# Fet anteriorment
# Entrenament del model
hist = model.fit(
train_dataset,
validation_data=val_dataset,
epochs=100,
callbacks=callbacks
)
grafica_accuracy_loss(hist, lr)
# Validació al conjunt de prova
test_loss, test_accuracy = model.evaluate(test_dataset)
print(f"Pèrdua al conjunt de prova: {test_loss:.4f}")
print(f"Accuracy al conjunt de prova: {test_accuracy:.4f}")
Epoch 1/100 148/148 ━━━━━━━━━━━━━━━━━━━━ 11s 55ms/step - accuracy: 0.1210 - loss: 2.4957 - val_accuracy: 0.2503 - val_loss: 2.1291 - learning_rate: 1.0000e-05 Epoch 2/100 148/148 ━━━━━━━━━━━━━━━━━━━━ 6s 42ms/step - accuracy: 0.1963 - loss: 2.2288 - val_accuracy: 0.2808 - val_loss: 2.0486 - learning_rate: 1.0000e-05 Epoch 3/100 148/148 ━━━━━━━━━━━━━━━━━━━━ 6s 42ms/step - accuracy: 0.2440 - loss: 2.1201 - val_accuracy: 0.2977 - val_loss: 1.9966 - learning_rate: 1.0000e-05 Epoch 4/100 148/148 ━━━━━━━━━━━━━━━━━━━━ 7s 44ms/step - accuracy: 0.2769 - loss: 2.0597 - val_accuracy: 0.3257 - val_loss: 1.9560 - learning_rate: 1.0000e-05 Epoch 5/100 148/148 ━━━━━━━━━━━━━━━━━━━━ 6s 41ms/step - accuracy: 0.2894 - loss: 2.0211 - val_accuracy: 0.3313 - val_loss: 1.9306 - learning_rate: 1.0000e-05 Epoch 6/100 148/148 ━━━━━━━━━━━━━━━━━━━━ 7s 46ms/step - accuracy: 0.3069 - loss: 1.9892 - val_accuracy: 0.3517 - val_loss: 1.9063 - learning_rate: 1.0000e-05 Epoch 7/100 148/148 ━━━━━━━━━━━━━━━━━━━━ 6s 43ms/step - accuracy: 0.3131 - loss: 1.9672 - val_accuracy: 0.3517 - val_loss: 1.8882 - learning_rate: 1.0000e-05 Epoch 8/100 148/148 ━━━━━━━━━━━━━━━━━━━━ 6s 43ms/step - accuracy: 0.3239 - loss: 1.9496 - val_accuracy: 0.3660 - val_loss: 1.8720 - learning_rate: 1.0000e-05 Epoch 9/100 148/148 ━━━━━━━━━━━━━━━━━━━━ 6s 43ms/step - accuracy: 0.3412 - loss: 1.9106 - val_accuracy: 0.3690 - val_loss: 1.8557 - learning_rate: 1.0000e-05 Epoch 10/100 148/148 ━━━━━━━━━━━━━━━━━━━━ 6s 43ms/step - accuracy: 0.3441 - loss: 1.8931 - val_accuracy: 0.3710 - val_loss: 1.8446 - learning_rate: 1.0000e-05
31/31 ━━━━━━━━━━━━━━━━━━━━ 1s 42ms/step - accuracy: 0.2736 - loss: 2.1254 Pèrdua al conjunt de prova: 2.1402 Accuracy al conjunt de prova: 0.2547
Ara implementarem una xarxa neuronal convolucional bàsica (CNN), que sol ser molt més eficaç per a la classificació de les imatges. Les CNN aprofiten l'estructura espacial de les dades mitjançant capes específiques que permeten extreure les característiques locals abans de la classificació.
Arquitectura proposada: Utilitzarem un model de Keras Sequential per a aquesta CNN. Consistirà en un bloc d’extractor de característiques i, a continuació, un classificador dens similar a l'anterior però més petit.
En aquesta secció utilitzarem les capes Conv2D, MaxPooling2D, Dense i Dropout de Keras.
Es proporciona el codi del classificador ja implementat.
# Definició del model CNN
cnn_model = keras.Sequential([
keras.layers.InputLayer(shape=(img_height, img_width, 3)),
Rescaling(1./255),
Conv2D(16, 3, padding='same', activation='relu'),
MaxPooling2D(),
Conv2D(32, 3, padding='same', activation='relu'),
MaxPooling2D(),
Conv2D(64, 3, padding='same', activation='relu'),
MaxPooling2D(),
Dropout(0.2),
Dense(64, activation='relu'),
Dropout(0.5),
Dense(10, activation='softmax')
], name="CNN_model")
cnn_model.summary()
Model: "CNN_model"
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━┓ ┃ Layer (type) ┃ Output Shape ┃ Param # ┃ ┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━┩ │ rescaling_3 (Rescaling) │ (None, 160, 160, 3) │ 0 │ ├──────────────────────────────────────┼─────────────────────────────┼─────────────────┤ │ conv2d (Conv2D) │ (None, 160, 160, 16) │ 448 │ ├──────────────────────────────────────┼─────────────────────────────┼─────────────────┤ │ max_pooling2d (MaxPooling2D) │ (None, 80, 80, 16) │ 0 │ ├──────────────────────────────────────┼─────────────────────────────┼─────────────────┤ │ conv2d_1 (Conv2D) │ (None, 80, 80, 32) │ 4,640 │ ├──────────────────────────────────────┼─────────────────────────────┼─────────────────┤ │ max_pooling2d_1 (MaxPooling2D) │ (None, 40, 40, 32) │ 0 │ ├──────────────────────────────────────┼─────────────────────────────┼─────────────────┤ │ conv2d_2 (Conv2D) │ (None, 40, 40, 64) │ 18,496 │ ├──────────────────────────────────────┼─────────────────────────────┼─────────────────┤ │ max_pooling2d_2 (MaxPooling2D) │ (None, 20, 20, 64) │ 0 │ ├──────────────────────────────────────┼─────────────────────────────┼─────────────────┤ │ dropout_6 (Dropout) │ (None, 20, 20, 64) │ 0 │ ├──────────────────────────────────────┼─────────────────────────────┼─────────────────┤ │ dense_9 (Dense) │ (None, 20, 20, 64) │ 4,160 │ ├──────────────────────────────────────┼─────────────────────────────┼─────────────────┤ │ dropout_7 (Dropout) │ (None, 20, 20, 64) │ 0 │ ├──────────────────────────────────────┼─────────────────────────────┼─────────────────┤ │ dense_10 (Dense) │ (None, 20, 20, 10) │ 650 │ └──────────────────────────────────────┴─────────────────────────────┴─────────────────┘
Total params: 28,394 (110.91 KB)
Trainable params: 28,394 (110.91 KB)
Non-trainable params: 0 (0.00 B)
# Compilar model CNN
cnn_model.compile(optimizer=Adam(learning_rate=1e-4),
loss='categorical_crossentropy',
metrics=['accuracy'])
# Definir callbacks EarlyStopping y ReduceLROnPlateau
callbacks = [
EarlyStopping(monitor='val_loss', patience=10, restore_best_weights=True),
ReduceLROnPlateau(monitor='val_loss', factor=0.2, patience=5, min_lr=1e-6)
]
# Entrenament del model CNN
hist = cnn_model.fit(
train_dataset,
validation_data=val_dataset,
epochs=100,
callbacks=callbacks
)
Epoch 1/100
--------------------------------------------------------------------------- ValueError Traceback (most recent call last) <ipython-input-25-c02148c35999> in <cell line: 2>() 1 # Entrenament del model CNN ----> 2 hist = cnn_model.fit( 3 train_dataset, 4 validation_data=val_dataset, 5 epochs=100, /usr/local/lib/python3.10/dist-packages/keras/src/utils/traceback_utils.py in error_handler(*args, **kwargs) 120 # To get the full stack trace, call: 121 # `keras.config.disable_traceback_filtering()` --> 122 raise e.with_traceback(filtered_tb) from None 123 finally: 124 del filtered_tb /usr/local/lib/python3.10/dist-packages/keras/src/backend/tensorflow/nn.py in categorical_crossentropy(target, output, from_logits, axis) 578 ) 579 if len(target.shape) != len(output.shape): --> 580 raise ValueError( 581 "Arguments `target` and `output` must have the same rank " 582 "(ndim). Received: " ValueError: Arguments `target` and `output` must have the same rank (ndim). Received: target.shape=(None, 10), output.shape=(None, 20, 20, 10)
# Escriu aquí el model corregit
cnn_model = keras.Sequential([
keras.layers.InputLayer(shape=(img_height, img_width, 3)),
Rescaling(1./255),
Conv2D(16, 3, padding='same', activation='relu'),
MaxPooling2D(),
Conv2D(32, 3, padding='same', activation='relu'),
MaxPooling2D(),
Conv2D(64, 3, padding='same', activation='relu'),
MaxPooling2D(),
Dropout(0.2),
Flatten(),
Dense(64, activation='relu'),
Dropout(0.5),
Dense(10, activation='softmax')
], name="CNN_model")
cnn_model.summary()
Model: "CNN_model"
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━┓ ┃ Layer (type) ┃ Output Shape ┃ Param # ┃ ┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━┩ │ rescaling_4 (Rescaling) │ (None, 160, 160, 3) │ 0 │ ├──────────────────────────────────────┼─────────────────────────────┼─────────────────┤ │ conv2d_3 (Conv2D) │ (None, 160, 160, 16) │ 448 │ ├──────────────────────────────────────┼─────────────────────────────┼─────────────────┤ │ max_pooling2d_3 (MaxPooling2D) │ (None, 80, 80, 16) │ 0 │ ├──────────────────────────────────────┼─────────────────────────────┼─────────────────┤ │ conv2d_4 (Conv2D) │ (None, 80, 80, 32) │ 4,640 │ ├──────────────────────────────────────┼─────────────────────────────┼─────────────────┤ │ max_pooling2d_4 (MaxPooling2D) │ (None, 40, 40, 32) │ 0 │ ├──────────────────────────────────────┼─────────────────────────────┼─────────────────┤ │ conv2d_5 (Conv2D) │ (None, 40, 40, 64) │ 18,496 │ ├──────────────────────────────────────┼─────────────────────────────┼─────────────────┤ │ max_pooling2d_5 (MaxPooling2D) │ (None, 20, 20, 64) │ 0 │ ├──────────────────────────────────────┼─────────────────────────────┼─────────────────┤ │ dropout_8 (Dropout) │ (None, 20, 20, 64) │ 0 │ ├──────────────────────────────────────┼─────────────────────────────┼─────────────────┤ │ flatten_3 (Flatten) │ (None, 25600) │ 0 │ ├──────────────────────────────────────┼─────────────────────────────┼─────────────────┤ │ dense_11 (Dense) │ (None, 64) │ 1,638,464 │ ├──────────────────────────────────────┼─────────────────────────────┼─────────────────┤ │ dropout_9 (Dropout) │ (None, 64) │ 0 │ ├──────────────────────────────────────┼─────────────────────────────┼─────────────────┤ │ dense_12 (Dense) │ (None, 10) │ 650 │ └──────────────────────────────────────┴─────────────────────────────┴─────────────────┘
Total params: 1,662,698 (6.34 MB)
Trainable params: 1,662,698 (6.34 MB)
Non-trainable params: 0 (0.00 B)
# Compilar model CNN
lr = 1e-4
cnn_model.compile(optimizer=Adam(learning_rate=lr),
loss='categorical_crossentropy',
metrics=['accuracy'])
# Definir callbacks EarlyStopping y ReduceLROnPlateau
callbacks = [
EarlyStopping(monitor='val_loss', patience=10, restore_best_weights=True),
ReduceLROnPlateau(monitor='val_loss', factor=0.2, patience=5, min_lr=1e-6)
]
# Entrenament del model CNN
hist = cnn_model.fit(
train_dataset,
validation_data=val_dataset,
epochs=100,
callbacks=callbacks
)
Epoch 1/100 148/148 ━━━━━━━━━━━━━━━━━━━━ 19s 86ms/step - accuracy: 0.1632 - loss: 2.2438 - val_accuracy: 0.3150 - val_loss: 1.9616 - learning_rate: 1.0000e-04 Epoch 2/100 148/148 ━━━━━━━━━━━━━━━━━━━━ 7s 45ms/step - accuracy: 0.2945 - loss: 1.9902 - val_accuracy: 0.4475 - val_loss: 1.7958 - learning_rate: 1.0000e-04 Epoch 3/100 148/148 ━━━━━━━━━━━━━━━━━━━━ 7s 45ms/step - accuracy: 0.3547 - loss: 1.8605 - val_accuracy: 0.4715 - val_loss: 1.6791 - learning_rate: 1.0000e-04 Epoch 4/100 148/148 ━━━━━━━━━━━━━━━━━━━━ 7s 46ms/step - accuracy: 0.3864 - loss: 1.7710 - val_accuracy: 0.5041 - val_loss: 1.5694 - learning_rate: 1.0000e-04 Epoch 5/100 148/148 ━━━━━━━━━━━━━━━━━━━━ 7s 46ms/step - accuracy: 0.4215 - loss: 1.6727 - val_accuracy: 0.5153 - val_loss: 1.5245 - learning_rate: 1.0000e-04 Epoch 6/100 148/148 ━━━━━━━━━━━━━━━━━━━━ 7s 44ms/step - accuracy: 0.4552 - loss: 1.6191 - val_accuracy: 0.5469 - val_loss: 1.4411 - learning_rate: 1.0000e-04 Epoch 7/100 148/148 ━━━━━━━━━━━━━━━━━━━━ 7s 44ms/step - accuracy: 0.4665 - loss: 1.5675 - val_accuracy: 0.5545 - val_loss: 1.4286 - learning_rate: 1.0000e-04 Epoch 8/100 148/148 ━━━━━━━━━━━━━━━━━━━━ 7s 45ms/step - accuracy: 0.4860 - loss: 1.5211 - val_accuracy: 0.5642 - val_loss: 1.3583 - learning_rate: 1.0000e-04 Epoch 9/100 148/148 ━━━━━━━━━━━━━━━━━━━━ 7s 48ms/step - accuracy: 0.5091 - loss: 1.4548 - val_accuracy: 0.5851 - val_loss: 1.3425 - learning_rate: 1.0000e-04 Epoch 10/100 148/148 ━━━━━━━━━━━━━━━━━━━━ 7s 45ms/step - accuracy: 0.5228 - loss: 1.4234 - val_accuracy: 0.5810 - val_loss: 1.3240 - learning_rate: 1.0000e-04 Epoch 11/100 148/148 ━━━━━━━━━━━━━━━━━━━━ 7s 45ms/step - accuracy: 0.5385 - loss: 1.3848 - val_accuracy: 0.5856 - val_loss: 1.2756 - learning_rate: 1.0000e-04 Epoch 12/100 148/148 ━━━━━━━━━━━━━━━━━━━━ 7s 45ms/step - accuracy: 0.5465 - loss: 1.3612 - val_accuracy: 0.5897 - val_loss: 1.2538 - learning_rate: 1.0000e-04 Epoch 13/100 148/148 ━━━━━━━━━━━━━━━━━━━━ 7s 45ms/step - accuracy: 0.5510 - loss: 1.3262 - val_accuracy: 0.6035 - val_loss: 1.2348 - learning_rate: 1.0000e-04 Epoch 14/100 148/148 ━━━━━━━━━━━━━━━━━━━━ 7s 49ms/step - accuracy: 0.5705 - loss: 1.2810 - val_accuracy: 0.6060 - val_loss: 1.2095 - learning_rate: 1.0000e-04 Epoch 15/100 148/148 ━━━━━━━━━━━━━━━━━━━━ 7s 44ms/step - accuracy: 0.5826 - loss: 1.2522 - val_accuracy: 0.6014 - val_loss: 1.2144 - learning_rate: 1.0000e-04 Epoch 16/100 148/148 ━━━━━━━━━━━━━━━━━━━━ 7s 45ms/step - accuracy: 0.5848 - loss: 1.2261 - val_accuracy: 0.6050 - val_loss: 1.2004 - learning_rate: 1.0000e-04 Epoch 17/100 148/148 ━━━━━━━━━━━━━━━━━━━━ 7s 46ms/step - accuracy: 0.5905 - loss: 1.2209 - val_accuracy: 0.6106 - val_loss: 1.1886 - learning_rate: 1.0000e-04 Epoch 18/100 148/148 ━━━━━━━━━━━━━━━━━━━━ 7s 46ms/step - accuracy: 0.5986 - loss: 1.1891 - val_accuracy: 0.6162 - val_loss: 1.1687 - learning_rate: 1.0000e-04 Epoch 19/100 148/148 ━━━━━━━━━━━━━━━━━━━━ 7s 48ms/step - accuracy: 0.6089 - loss: 1.1750 - val_accuracy: 0.6295 - val_loss: 1.1514 - learning_rate: 1.0000e-04 Epoch 20/100 148/148 ━━━━━━━━━━━━━━━━━━━━ 7s 45ms/step - accuracy: 0.6114 - loss: 1.1518 - val_accuracy: 0.6182 - val_loss: 1.1745 - learning_rate: 1.0000e-04 Epoch 21/100 148/148 ━━━━━━━━━━━━━━━━━━━━ 7s 46ms/step - accuracy: 0.6204 - loss: 1.1215 - val_accuracy: 0.6259 - val_loss: 1.1507 - learning_rate: 1.0000e-04 Epoch 22/100 148/148 ━━━━━━━━━━━━━━━━━━━━ 7s 45ms/step - accuracy: 0.6305 - loss: 1.0940 - val_accuracy: 0.6228 - val_loss: 1.1378 - learning_rate: 1.0000e-04 Epoch 23/100 148/148 ━━━━━━━━━━━━━━━━━━━━ 7s 47ms/step - accuracy: 0.6462 - loss: 1.0725 - val_accuracy: 0.6356 - val_loss: 1.1246 - learning_rate: 1.0000e-04 Epoch 24/100 148/148 ━━━━━━━━━━━━━━━━━━━━ 7s 45ms/step - accuracy: 0.6474 - loss: 1.0461 - val_accuracy: 0.6239 - val_loss: 1.1419 - learning_rate: 1.0000e-04 Epoch 25/100 148/148 ━━━━━━━━━━━━━━━━━━━━ 7s 45ms/step - accuracy: 0.6558 - loss: 1.0306 - val_accuracy: 0.6346 - val_loss: 1.1268 - learning_rate: 1.0000e-04 Epoch 26/100 148/148 ━━━━━━━━━━━━━━━━━━━━ 7s 45ms/step - accuracy: 0.6473 - loss: 1.0006 - val_accuracy: 0.6340 - val_loss: 1.1180 - learning_rate: 1.0000e-04 Epoch 27/100 148/148 ━━━━━━━━━━━━━━━━━━━━ 7s 46ms/step - accuracy: 0.6728 - loss: 0.9894 - val_accuracy: 0.6279 - val_loss: 1.1199 - learning_rate: 1.0000e-04 Epoch 28/100 148/148 ━━━━━━━━━━━━━━━━━━━━ 7s 48ms/step - accuracy: 0.6724 - loss: 0.9673 - val_accuracy: 0.6259 - val_loss: 1.1202 - learning_rate: 1.0000e-04 Epoch 29/100 148/148 ━━━━━━━━━━━━━━━━━━━━ 7s 45ms/step - accuracy: 0.6695 - loss: 0.9635 - val_accuracy: 0.6376 - val_loss: 1.1031 - learning_rate: 1.0000e-04 Epoch 30/100 148/148 ━━━━━━━━━━━━━━━━━━━━ 7s 45ms/step - accuracy: 0.6841 - loss: 0.9427 - val_accuracy: 0.6371 - val_loss: 1.1219 - learning_rate: 1.0000e-04 Epoch 31/100 148/148 ━━━━━━━━━━━━━━━━━━━━ 7s 44ms/step - accuracy: 0.6757 - loss: 0.9393 - val_accuracy: 0.6320 - val_loss: 1.1215 - learning_rate: 1.0000e-04 Epoch 32/100 148/148 ━━━━━━━━━━━━━━━━━━━━ 7s 44ms/step - accuracy: 0.6854 - loss: 0.9238 - val_accuracy: 0.6417 - val_loss: 1.0913 - learning_rate: 1.0000e-04 Epoch 33/100 148/148 ━━━━━━━━━━━━━━━━━━━━ 7s 47ms/step - accuracy: 0.6984 - loss: 0.8957 - val_accuracy: 0.6346 - val_loss: 1.1303 - learning_rate: 1.0000e-04 Epoch 34/100 148/148 ━━━━━━━━━━━━━━━━━━━━ 7s 45ms/step - accuracy: 0.7035 - loss: 0.8665 - val_accuracy: 0.6351 - val_loss: 1.1054 - learning_rate: 1.0000e-04 Epoch 35/100 148/148 ━━━━━━━━━━━━━━━━━━━━ 7s 45ms/step - accuracy: 0.6992 - loss: 0.8609 - val_accuracy: 0.6356 - val_loss: 1.1139 - learning_rate: 1.0000e-04 Epoch 36/100 148/148 ━━━━━━━━━━━━━━━━━━━━ 7s 46ms/step - accuracy: 0.7023 - loss: 0.8438 - val_accuracy: 0.6437 - val_loss: 1.0945 - learning_rate: 1.0000e-04 Epoch 37/100 148/148 ━━━━━━━━━━━━━━━━━━━━ 7s 45ms/step - accuracy: 0.7199 - loss: 0.8040 - val_accuracy: 0.6371 - val_loss: 1.1160 - learning_rate: 1.0000e-04 Epoch 38/100 148/148 ━━━━━━━━━━━━━━━━━━━━ 7s 48ms/step - accuracy: 0.7281 - loss: 0.7922 - val_accuracy: 0.6468 - val_loss: 1.0915 - learning_rate: 2.0000e-05 Epoch 39/100 148/148 ━━━━━━━━━━━━━━━━━━━━ 7s 44ms/step - accuracy: 0.7361 - loss: 0.7760 - val_accuracy: 0.6442 - val_loss: 1.0933 - learning_rate: 2.0000e-05 Epoch 40/100 148/148 ━━━━━━━━━━━━━━━━━━━━ 7s 45ms/step - accuracy: 0.7363 - loss: 0.7654 - val_accuracy: 0.6483 - val_loss: 1.0896 - learning_rate: 2.0000e-05 Epoch 41/100 148/148 ━━━━━━━━━━━━━━━━━━━━ 7s 44ms/step - accuracy: 0.7318 - loss: 0.7714 - val_accuracy: 0.6473 - val_loss: 1.0880 - learning_rate: 2.0000e-05 Epoch 42/100 148/148 ━━━━━━━━━━━━━━━━━━━━ 7s 47ms/step - accuracy: 0.7316 - loss: 0.7611 - val_accuracy: 0.6498 - val_loss: 1.0865 - learning_rate: 2.0000e-05 Epoch 43/100 148/148 ━━━━━━━━━━━━━━━━━━━━ 7s 47ms/step - accuracy: 0.7420 - loss: 0.7376 - val_accuracy: 0.6422 - val_loss: 1.0896 - learning_rate: 2.0000e-05 Epoch 44/100 148/148 ━━━━━━━━━━━━━━━━━━━━ 7s 46ms/step - accuracy: 0.7457 - loss: 0.7375 - val_accuracy: 0.6453 - val_loss: 1.1023 - learning_rate: 2.0000e-05 Epoch 45/100 148/148 ━━━━━━━━━━━━━━━━━━━━ 7s 47ms/step - accuracy: 0.7372 - loss: 0.7537 - val_accuracy: 0.6442 - val_loss: 1.0981 - learning_rate: 2.0000e-05 Epoch 46/100 148/148 ━━━━━━━━━━━━━━━━━━━━ 7s 46ms/step - accuracy: 0.7363 - loss: 0.7537 - val_accuracy: 0.6463 - val_loss: 1.0938 - learning_rate: 2.0000e-05 Epoch 47/100 148/148 ━━━━━━━━━━━━━━━━━━━━ 7s 49ms/step - accuracy: 0.7497 - loss: 0.7376 - val_accuracy: 0.6458 - val_loss: 1.0868 - learning_rate: 2.0000e-05 Epoch 48/100 148/148 ━━━━━━━━━━━━━━━━━━━━ 7s 45ms/step - accuracy: 0.7497 - loss: 0.7166 - val_accuracy: 0.6468 - val_loss: 1.0893 - learning_rate: 4.0000e-06 Epoch 49/100 148/148 ━━━━━━━━━━━━━━━━━━━━ 7s 45ms/step - accuracy: 0.7545 - loss: 0.7222 - val_accuracy: 0.6514 - val_loss: 1.0878 - learning_rate: 4.0000e-06 Epoch 50/100 148/148 ━━━━━━━━━━━━━━━━━━━━ 7s 47ms/step - accuracy: 0.7492 - loss: 0.7250 - val_accuracy: 0.6488 - val_loss: 1.0862 - learning_rate: 4.0000e-06 Epoch 51/100 148/148 ━━━━━━━━━━━━━━━━━━━━ 7s 45ms/step - accuracy: 0.7487 - loss: 0.7311 - val_accuracy: 0.6488 - val_loss: 1.0878 - learning_rate: 4.0000e-06 Epoch 52/100 148/148 ━━━━━━━━━━━━━━━━━━━━ 7s 49ms/step - accuracy: 0.7558 - loss: 0.7212 - val_accuracy: 0.6488 - val_loss: 1.0907 - learning_rate: 4.0000e-06 Epoch 53/100 148/148 ━━━━━━━━━━━━━━━━━━━━ 7s 46ms/step - accuracy: 0.7463 - loss: 0.7172 - val_accuracy: 0.6493 - val_loss: 1.0910 - learning_rate: 4.0000e-06 Epoch 54/100 148/148 ━━━━━━━━━━━━━━━━━━━━ 7s 45ms/step - accuracy: 0.7562 - loss: 0.7114 - val_accuracy: 0.6509 - val_loss: 1.0897 - learning_rate: 4.0000e-06 Epoch 55/100 148/148 ━━━━━━━━━━━━━━━━━━━━ 7s 46ms/step - accuracy: 0.7530 - loss: 0.7201 - val_accuracy: 0.6509 - val_loss: 1.0896 - learning_rate: 4.0000e-06 Epoch 56/100 148/148 ━━━━━━━━━━━━━━━━━━━━ 7s 48ms/step - accuracy: 0.7504 - loss: 0.7252 - val_accuracy: 0.6504 - val_loss: 1.0865 - learning_rate: 1.0000e-06 Epoch 57/100 148/148 ━━━━━━━━━━━━━━━━━━━━ 7s 45ms/step - accuracy: 0.7503 - loss: 0.7182 - val_accuracy: 0.6488 - val_loss: 1.0855 - learning_rate: 1.0000e-06 Epoch 58/100 148/148 ━━━━━━━━━━━━━━━━━━━━ 7s 45ms/step - accuracy: 0.7595 - loss: 0.7151 - val_accuracy: 0.6493 - val_loss: 1.0860 - learning_rate: 1.0000e-06 Epoch 59/100 148/148 ━━━━━━━━━━━━━━━━━━━━ 7s 45ms/step - accuracy: 0.7536 - loss: 0.7137 - val_accuracy: 0.6488 - val_loss: 1.0867 - learning_rate: 1.0000e-06 Epoch 60/100 148/148 ━━━━━━━━━━━━━━━━━━━━ 7s 45ms/step - accuracy: 0.7549 - loss: 0.7104 - val_accuracy: 0.6478 - val_loss: 1.0862 - learning_rate: 1.0000e-06 Epoch 61/100 148/148 ━━━━━━━━━━━━━━━━━━━━ 7s 49ms/step - accuracy: 0.7503 - loss: 0.7150 - val_accuracy: 0.6504 - val_loss: 1.0873 - learning_rate: 1.0000e-06 Epoch 62/100 148/148 ━━━━━━━━━━━━━━━━━━━━ 7s 46ms/step - accuracy: 0.7400 - loss: 0.7230 - val_accuracy: 0.6478 - val_loss: 1.0867 - learning_rate: 1.0000e-06 Epoch 63/100 148/148 ━━━━━━━━━━━━━━━━━━━━ 7s 45ms/step - accuracy: 0.7523 - loss: 0.7170 - val_accuracy: 0.6483 - val_loss: 1.0866 - learning_rate: 1.0000e-06 Epoch 64/100 148/148 ━━━━━━━━━━━━━━━━━━━━ 7s 46ms/step - accuracy: 0.7554 - loss: 0.7114 - val_accuracy: 0.6483 - val_loss: 1.0869 - learning_rate: 1.0000e-06 Epoch 65/100 148/148 ━━━━━━━━━━━━━━━━━━━━ 7s 46ms/step - accuracy: 0.7521 - loss: 0.7211 - val_accuracy: 0.6488 - val_loss: 1.0868 - learning_rate: 1.0000e-06 Epoch 66/100 148/148 ━━━━━━━━━━━━━━━━━━━━ 7s 47ms/step - accuracy: 0.7496 - loss: 0.7196 - val_accuracy: 0.6488 - val_loss: 1.0870 - learning_rate: 1.0000e-06 Epoch 67/100 148/148 ━━━━━━━━━━━━━━━━━━━━ 7s 45ms/step - accuracy: 0.7376 - loss: 0.7238 - val_accuracy: 0.6483 - val_loss: 1.0875 - learning_rate: 1.0000e-06
# Resultats
grafica_accuracy_loss(hist, lr)
# Validació al conjunt de prova
test_loss, test_accuracy = cnn_model.evaluate(test_dataset)
print(f"Pèrdua al conjunt de prova: {test_loss:.4f}")
print(f"Accuracy al conjunt de prova: {test_accuracy:.4f}")
31/31 ━━━━━━━━━━━━━━━━━━━━ 1s 36ms/step - accuracy: 0.6587 - loss: 1.0985 Pèrdua al conjunt de prova: 1.1145 Accuracy al conjunt de prova: 0.6490
Tot i que la CNN entrenada és força efectiva, podem intentar millorar-ne la generalització augmentant artificialment la mida i la diversitat del conjunt d'entrenament mitjançant tècniques d'augmentació de dades. L'augmentació consisteix a aplicar transformacions aleatòries a les imatges (girs, rotacions, zoom, etc.) de manera que el model rebi variants de les imatges originals a cada època, simulant tenir més dades
Keras proporciona capes de preprocessament d'imatge que realitzen aquestes transformacions de manera eficient durant l'entrenament, per exemple RandomFlip, RandomRotation, RandomZoom entre d'altres.
Aquí farem servir algunes d'aquestes capes per implementar l'augment. En particular, provarem de voltejar horitzontalment les imatges i aplicar petites rotacions aleatòries.
RandomFlip que voltegi aleatòriament les imatges horitzontalment.RandomRotation amb factor de rotació de 0.1 (±10%).from tensorflow.keras import layers
# Model d'augmentació de dades
data_augmentation = keras.Sequential([
layers.RandomFlip("horizontal"), # Volteig horitzontal
layers.RandomRotation(0.1) # Rotació ±10%
], name="data_augmentation")
# Prendre un batch d'entrenament i obtenir una imatge
for images, labels in train_dataset.take(1):
sample_image = images[0]
break
# Aplicar augmentació diverses vegades i visualitzar
plt.figure(figsize=(10, 2))
for i in range(5):
augmented_image = data_augmentation(tf.expand_dims(sample_image, 0))
plt.subplot(1, 5, i+1)
plt.imshow(augmented_image[0].numpy().astype("uint8"))
plt.axis("off")
plt.tight_layout()
plt.show()
Ara incorporarem la capa d’augmentació al model CNN per entrenar-la amb dades augmentades a cada època.
Rescaling i la primera Conv2D del model CNN anterior. És a dir, modifica l'arquitectura perquè les imatges d'entrada, després de ser reescalades, passin per les capes de flip i rotation aleatòries, i després continuïn per la CNN. A continuació, compila i entrena el nou model CNN augmentat seguint les mateixes indicacions que en l'exercici anterior, excepte que aquesta vegada farem servir un learning rate inicial més gran (1e-3). Mantingues EarlyStopping, ReduceLROnPlateau, 100 èpoques, etc. Després avalua el model final al conjunt de test.
Comenta les diferències amb el model sense augmentació en termes de:
# Model CNN amb augmentació de dades
cnn_aug_model = tf.keras.Sequential([
tf.keras.layers.InputLayer(shape=(img_height, img_width, 3)),
layers.Rescaling(1./255),
data_augmentation,
layers.Conv2D(16, 3, padding='same', activation='relu'),
layers.MaxPooling2D(),
layers.Conv2D(32, 3, padding='same', activation='relu'),
layers.MaxPooling2D(),
layers.Conv2D(64, 3, padding='same', activation='relu'),
layers.MaxPooling2D(),
layers.Dropout(0.2),
layers.Flatten(),
layers.Dense(64, activation='relu'),
layers.Dropout(0.5),
layers.Dense(10, activation='softmax')
], name="CNN_aug_model")
cnn_aug_model.summary()
Model: "CNN_aug_model"
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━┓ ┃ Layer (type) ┃ Output Shape ┃ Param # ┃ ┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━┩ │ rescaling_5 (Rescaling) │ (None, 160, 160, 3) │ 0 │ ├──────────────────────────────────────┼─────────────────────────────┼─────────────────┤ │ data_augmentation (Sequential) │ (None, 160, 160, 3) │ 0 │ ├──────────────────────────────────────┼─────────────────────────────┼─────────────────┤ │ conv2d_6 (Conv2D) │ (None, 160, 160, 16) │ 448 │ ├──────────────────────────────────────┼─────────────────────────────┼─────────────────┤ │ max_pooling2d_6 (MaxPooling2D) │ (None, 80, 80, 16) │ 0 │ ├──────────────────────────────────────┼─────────────────────────────┼─────────────────┤ │ conv2d_7 (Conv2D) │ (None, 80, 80, 32) │ 4,640 │ ├──────────────────────────────────────┼─────────────────────────────┼─────────────────┤ │ max_pooling2d_7 (MaxPooling2D) │ (None, 40, 40, 32) │ 0 │ ├──────────────────────────────────────┼─────────────────────────────┼─────────────────┤ │ conv2d_8 (Conv2D) │ (None, 40, 40, 64) │ 18,496 │ ├──────────────────────────────────────┼─────────────────────────────┼─────────────────┤ │ max_pooling2d_8 (MaxPooling2D) │ (None, 20, 20, 64) │ 0 │ ├──────────────────────────────────────┼─────────────────────────────┼─────────────────┤ │ dropout_10 (Dropout) │ (None, 20, 20, 64) │ 0 │ ├──────────────────────────────────────┼─────────────────────────────┼─────────────────┤ │ flatten_4 (Flatten) │ (None, 25600) │ 0 │ ├──────────────────────────────────────┼─────────────────────────────┼─────────────────┤ │ dense_13 (Dense) │ (None, 64) │ 1,638,464 │ ├──────────────────────────────────────┼─────────────────────────────┼─────────────────┤ │ dropout_11 (Dropout) │ (None, 64) │ 0 │ ├──────────────────────────────────────┼─────────────────────────────┼─────────────────┤ │ dense_14 (Dense) │ (None, 10) │ 650 │ └──────────────────────────────────────┴─────────────────────────────┴─────────────────┘
Total params: 1,662,698 (6.34 MB)
Trainable params: 1,662,698 (6.34 MB)
Non-trainable params: 0 (0.00 B)
# Compilació de la xarxa
lr = 1e-3
cnn_aug_model.compile(
optimizer=Adam(learning_rate=lr),
loss='categorical_crossentropy',
metrics=['accuracy']
)
callbacks = [
tf.keras.callbacks.EarlyStopping(monitor='val_loss', patience=10, restore_best_weights=True),
tf.keras.callbacks.ReduceLROnPlateau(monitor='val_loss', factor=0.2, patience=5, min_lr=1e-6)
]
# Entrenament
hist_aug = cnn_aug_model.fit(
train_dataset,
validation_data=val_dataset,
epochs=100,
callbacks=callbacks
)
Epoch 1/100 148/148 ━━━━━━━━━━━━━━━━━━━━ 11s 51ms/step - accuracy: 0.2032 - loss: 2.1754 - val_accuracy: 0.4358 - val_loss: 1.7140 - learning_rate: 0.0010 Epoch 2/100 148/148 ━━━━━━━━━━━━━━━━━━━━ 7s 49ms/step - accuracy: 0.3574 - loss: 1.8574 - val_accuracy: 0.5082 - val_loss: 1.5215 - learning_rate: 0.0010 Epoch 3/100 148/148 ━━━━━━━━━━━━━━━━━━━━ 7s 47ms/step - accuracy: 0.4155 - loss: 1.6918 - val_accuracy: 0.5601 - val_loss: 1.3669 - learning_rate: 0.0010 Epoch 4/100 148/148 ━━━━━━━━━━━━━━━━━━━━ 7s 48ms/step - accuracy: 0.4523 - loss: 1.6082 - val_accuracy: 0.5973 - val_loss: 1.2956 - learning_rate: 0.0010 Epoch 5/100 148/148 ━━━━━━━━━━━━━━━━━━━━ 7s 47ms/step - accuracy: 0.4773 - loss: 1.5498 - val_accuracy: 0.6055 - val_loss: 1.2546 - learning_rate: 0.0010 Epoch 6/100 148/148 ━━━━━━━━━━━━━━━━━━━━ 7s 50ms/step - accuracy: 0.4938 - loss: 1.5010 - val_accuracy: 0.6259 - val_loss: 1.1793 - learning_rate: 0.0010 Epoch 7/100 148/148 ━━━━━━━━━━━━━━━━━━━━ 7s 47ms/step - accuracy: 0.5044 - loss: 1.4593 - val_accuracy: 0.6182 - val_loss: 1.1757 - learning_rate: 0.0010 Epoch 8/100 148/148 ━━━━━━━━━━━━━━━━━━━━ 7s 46ms/step - accuracy: 0.5191 - loss: 1.4346 - val_accuracy: 0.5882 - val_loss: 1.2816 - learning_rate: 0.0010 Epoch 9/100 148/148 ━━━━━━━━━━━━━━━━━━━━ 7s 46ms/step - accuracy: 0.5325 - loss: 1.4101 - val_accuracy: 0.6488 - val_loss: 1.1555 - learning_rate: 0.0010 Epoch 10/100 148/148 ━━━━━━━━━━━━━━━━━━━━ 7s 46ms/step - accuracy: 0.5236 - loss: 1.4081 - val_accuracy: 0.6162 - val_loss: 1.1828 - learning_rate: 0.0010 Epoch 11/100 148/148 ━━━━━━━━━━━━━━━━━━━━ 7s 49ms/step - accuracy: 0.5298 - loss: 1.3845 - val_accuracy: 0.6228 - val_loss: 1.1820 - learning_rate: 0.0010 Epoch 12/100 148/148 ━━━━━━━━━━━━━━━━━━━━ 7s 46ms/step - accuracy: 0.5469 - loss: 1.3380 - val_accuracy: 0.6478 - val_loss: 1.1212 - learning_rate: 0.0010 Epoch 13/100 148/148 ━━━━━━━━━━━━━━━━━━━━ 7s 48ms/step - accuracy: 0.5509 - loss: 1.3053 - val_accuracy: 0.6651 - val_loss: 1.0803 - learning_rate: 0.0010 Epoch 14/100 148/148 ━━━━━━━━━━━━━━━━━━━━ 7s 48ms/step - accuracy: 0.5654 - loss: 1.2917 - val_accuracy: 0.6264 - val_loss: 1.1992 - learning_rate: 0.0010 Epoch 15/100 148/148 ━━━━━━━━━━━━━━━━━━━━ 7s 48ms/step - accuracy: 0.5673 - loss: 1.2604 - val_accuracy: 0.6376 - val_loss: 1.1459 - learning_rate: 0.0010 Epoch 16/100 148/148 ━━━━━━━━━━━━━━━━━━━━ 7s 49ms/step - accuracy: 0.5797 - loss: 1.2654 - val_accuracy: 0.6871 - val_loss: 1.0211 - learning_rate: 0.0010 Epoch 17/100 148/148 ━━━━━━━━━━━━━━━━━━━━ 7s 46ms/step - accuracy: 0.5704 - loss: 1.2551 - val_accuracy: 0.6702 - val_loss: 1.0626 - learning_rate: 0.0010 Epoch 18/100 148/148 ━━━━━━━━━━━━━━━━━━━━ 7s 46ms/step - accuracy: 0.5947 - loss: 1.2186 - val_accuracy: 0.6748 - val_loss: 1.0450 - learning_rate: 0.0010 Epoch 19/100 148/148 ━━━━━━━━━━━━━━━━━━━━ 7s 46ms/step - accuracy: 0.5971 - loss: 1.2014 - val_accuracy: 0.6636 - val_loss: 1.0452 - learning_rate: 0.0010 Epoch 20/100 148/148 ━━━━━━━━━━━━━━━━━━━━ 7s 49ms/step - accuracy: 0.5944 - loss: 1.2056 - val_accuracy: 0.6641 - val_loss: 1.0678 - learning_rate: 0.0010 Epoch 21/100 148/148 ━━━━━━━━━━━━━━━━━━━━ 7s 48ms/step - accuracy: 0.6162 - loss: 1.1679 - val_accuracy: 0.6901 - val_loss: 1.0062 - learning_rate: 0.0010 Epoch 22/100 148/148 ━━━━━━━━━━━━━━━━━━━━ 7s 47ms/step - accuracy: 0.6194 - loss: 1.1374 - val_accuracy: 0.6718 - val_loss: 1.0484 - learning_rate: 0.0010 Epoch 23/100 148/148 ━━━━━━━━━━━━━━━━━━━━ 7s 48ms/step - accuracy: 0.6188 - loss: 1.1289 - val_accuracy: 0.6575 - val_loss: 1.0955 - learning_rate: 0.0010 Epoch 24/100 148/148 ━━━━━━━━━━━━━━━━━━━━ 7s 47ms/step - accuracy: 0.6131 - loss: 1.1409 - val_accuracy: 0.6998 - val_loss: 0.9658 - learning_rate: 0.0010 Epoch 25/100 148/148 ━━━━━━━━━━━━━━━━━━━━ 7s 49ms/step - accuracy: 0.6329 - loss: 1.1140 - val_accuracy: 0.6820 - val_loss: 1.0303 - learning_rate: 0.0010 Epoch 26/100 148/148 ━━━━━━━━━━━━━━━━━━━━ 7s 46ms/step - accuracy: 0.6296 - loss: 1.1048 - val_accuracy: 0.6957 - val_loss: 0.9878 - learning_rate: 0.0010 Epoch 27/100 148/148 ━━━━━━━━━━━━━━━━━━━━ 7s 46ms/step - accuracy: 0.6265 - loss: 1.1197 - val_accuracy: 0.6983 - val_loss: 0.9928 - learning_rate: 0.0010 Epoch 28/100 148/148 ━━━━━━━━━━━━━━━━━━━━ 7s 47ms/step - accuracy: 0.6320 - loss: 1.1026 - val_accuracy: 0.7074 - val_loss: 0.9503 - learning_rate: 0.0010 Epoch 29/100 148/148 ━━━━━━━━━━━━━━━━━━━━ 7s 49ms/step - accuracy: 0.6290 - loss: 1.0996 - val_accuracy: 0.7013 - val_loss: 0.9811 - learning_rate: 0.0010 Epoch 30/100 148/148 ━━━━━━━━━━━━━━━━━━━━ 7s 48ms/step - accuracy: 0.6402 - loss: 1.0925 - val_accuracy: 0.6769 - val_loss: 1.0365 - learning_rate: 0.0010 Epoch 31/100 148/148 ━━━━━━━━━━━━━━━━━━━━ 7s 47ms/step - accuracy: 0.6398 - loss: 1.0821 - val_accuracy: 0.7059 - val_loss: 0.9707 - learning_rate: 0.0010 Epoch 32/100 148/148 ━━━━━━━━━━━━━━━━━━━━ 7s 48ms/step - accuracy: 0.6408 - loss: 1.0515 - val_accuracy: 0.7034 - val_loss: 0.9576 - learning_rate: 0.0010 Epoch 33/100 148/148 ━━━━━━━━━━━━━━━━━━━━ 7s 50ms/step - accuracy: 0.6514 - loss: 1.0410 - val_accuracy: 0.6707 - val_loss: 1.0475 - learning_rate: 0.0010 Epoch 34/100 148/148 ━━━━━━━━━━━━━━━━━━━━ 7s 50ms/step - accuracy: 0.6550 - loss: 1.0288 - val_accuracy: 0.7283 - val_loss: 0.8876 - learning_rate: 2.0000e-04 Epoch 35/100 148/148 ━━━━━━━━━━━━━━━━━━━━ 7s 47ms/step - accuracy: 0.6691 - loss: 0.9798 - val_accuracy: 0.7161 - val_loss: 0.9213 - learning_rate: 2.0000e-04 Epoch 36/100 148/148 ━━━━━━━━━━━━━━━━━━━━ 7s 46ms/step - accuracy: 0.6799 - loss: 0.9512 - val_accuracy: 0.7217 - val_loss: 0.8982 - learning_rate: 2.0000e-04 Epoch 37/100 148/148 ━━━━━━━━━━━━━━━━━━━━ 7s 46ms/step - accuracy: 0.6806 - loss: 0.9396 - val_accuracy: 0.7130 - val_loss: 0.9157 - learning_rate: 2.0000e-04 Epoch 38/100 148/148 ━━━━━━━━━━━━━━━━━━━━ 8s 52ms/step - accuracy: 0.6780 - loss: 0.9455 - val_accuracy: 0.7283 - val_loss: 0.8861 - learning_rate: 2.0000e-04 Epoch 39/100 148/148 ━━━━━━━━━━━━━━━━━━━━ 7s 48ms/step - accuracy: 0.6822 - loss: 0.9566 - val_accuracy: 0.7253 - val_loss: 0.8906 - learning_rate: 2.0000e-04 Epoch 40/100 148/148 ━━━━━━━━━━━━━━━━━━━━ 7s 48ms/step - accuracy: 0.6900 - loss: 0.9335 - val_accuracy: 0.7304 - val_loss: 0.8679 - learning_rate: 2.0000e-04 Epoch 41/100 148/148 ━━━━━━━━━━━━━━━━━━━━ 7s 48ms/step - accuracy: 0.6849 - loss: 0.9290 - val_accuracy: 0.7258 - val_loss: 0.8849 - learning_rate: 2.0000e-04 Epoch 42/100 148/148 ━━━━━━━━━━━━━━━━━━━━ 7s 48ms/step - accuracy: 0.6855 - loss: 0.9317 - val_accuracy: 0.7141 - val_loss: 0.9165 - learning_rate: 2.0000e-04 Epoch 43/100 148/148 ━━━━━━━━━━━━━━━━━━━━ 7s 50ms/step - accuracy: 0.6908 - loss: 0.9338 - val_accuracy: 0.7232 - val_loss: 0.8930 - learning_rate: 2.0000e-04 Epoch 44/100 148/148 ━━━━━━━━━━━━━━━━━━━━ 7s 48ms/step - accuracy: 0.6844 - loss: 0.9300 - val_accuracy: 0.7202 - val_loss: 0.9150 - learning_rate: 2.0000e-04 Epoch 45/100 148/148 ━━━━━━━━━━━━━━━━━━━━ 7s 47ms/step - accuracy: 0.6890 - loss: 0.9134 - val_accuracy: 0.7273 - val_loss: 0.8932 - learning_rate: 2.0000e-04 Epoch 46/100 148/148 ━━━━━━━━━━━━━━━━━━━━ 7s 47ms/step - accuracy: 0.6841 - loss: 0.9220 - val_accuracy: 0.7334 - val_loss: 0.8729 - learning_rate: 4.0000e-05 Epoch 47/100 148/148 ━━━━━━━━━━━━━━━━━━━━ 8s 51ms/step - accuracy: 0.6889 - loss: 0.9195 - val_accuracy: 0.7314 - val_loss: 0.8701 - learning_rate: 4.0000e-05 Epoch 48/100 148/148 ━━━━━━━━━━━━━━━━━━━━ 7s 48ms/step - accuracy: 0.7026 - loss: 0.8994 - val_accuracy: 0.7319 - val_loss: 0.8671 - learning_rate: 4.0000e-05 Epoch 49/100 148/148 ━━━━━━━━━━━━━━━━━━━━ 7s 48ms/step - accuracy: 0.7033 - loss: 0.8798 - val_accuracy: 0.7334 - val_loss: 0.8660 - learning_rate: 4.0000e-05 Epoch 50/100 148/148 ━━━━━━━━━━━━━━━━━━━━ 7s 47ms/step - accuracy: 0.6926 - loss: 0.9023 - val_accuracy: 0.7324 - val_loss: 0.8714 - learning_rate: 4.0000e-05 Epoch 51/100 148/148 ━━━━━━━━━━━━━━━━━━━━ 7s 47ms/step - accuracy: 0.6892 - loss: 0.8945 - val_accuracy: 0.7309 - val_loss: 0.8659 - learning_rate: 4.0000e-05 Epoch 52/100 148/148 ━━━━━━━━━━━━━━━━━━━━ 7s 49ms/step - accuracy: 0.6981 - loss: 0.8979 - val_accuracy: 0.7324 - val_loss: 0.8635 - learning_rate: 4.0000e-05 Epoch 53/100 148/148 ━━━━━━━━━━━━━━━━━━━━ 7s 46ms/step - accuracy: 0.6908 - loss: 0.9067 - val_accuracy: 0.7339 - val_loss: 0.8683 - learning_rate: 4.0000e-05 Epoch 54/100 148/148 ━━━━━━━━━━━━━━━━━━━━ 7s 47ms/step - accuracy: 0.7019 - loss: 0.8919 - val_accuracy: 0.7273 - val_loss: 0.8768 - learning_rate: 4.0000e-05 Epoch 55/100 148/148 ━━━━━━━━━━━━━━━━━━━━ 7s 48ms/step - accuracy: 0.7037 - loss: 0.8936 - val_accuracy: 0.7283 - val_loss: 0.8736 - learning_rate: 4.0000e-05 Epoch 56/100 148/148 ━━━━━━━━━━━━━━━━━━━━ 8s 51ms/step - accuracy: 0.7012 - loss: 0.8852 - val_accuracy: 0.7309 - val_loss: 0.8773 - learning_rate: 4.0000e-05 Epoch 57/100 148/148 ━━━━━━━━━━━━━━━━━━━━ 7s 48ms/step - accuracy: 0.6907 - loss: 0.9069 - val_accuracy: 0.7329 - val_loss: 0.8691 - learning_rate: 4.0000e-05 Epoch 58/100 148/148 ━━━━━━━━━━━━━━━━━━━━ 7s 48ms/step - accuracy: 0.6940 - loss: 0.9004 - val_accuracy: 0.7339 - val_loss: 0.8700 - learning_rate: 8.0000e-06 Epoch 59/100 148/148 ━━━━━━━━━━━━━━━━━━━━ 7s 48ms/step - accuracy: 0.6949 - loss: 0.8882 - val_accuracy: 0.7360 - val_loss: 0.8675 - learning_rate: 8.0000e-06 Epoch 60/100 148/148 ━━━━━━━━━━━━━━━━━━━━ 7s 47ms/step - accuracy: 0.6982 - loss: 0.8936 - val_accuracy: 0.7339 - val_loss: 0.8699 - learning_rate: 8.0000e-06 Epoch 61/100 148/148 ━━━━━━━━━━━━━━━━━━━━ 7s 50ms/step - accuracy: 0.7003 - loss: 0.8639 - val_accuracy: 0.7350 - val_loss: 0.8663 - learning_rate: 8.0000e-06 Epoch 62/100 148/148 ━━━━━━━━━━━━━━━━━━━━ 7s 47ms/step - accuracy: 0.6894 - loss: 0.9086 - val_accuracy: 0.7365 - val_loss: 0.8679 - learning_rate: 8.0000e-06
# Resultats
grafica_accuracy_loss(hist_aug, lr)
test_loss, test_accuracy = cnn_aug_model.evaluate(test_dataset)
print(f"Pèrdua al conjunt de prova: {test_loss:.4f}")
print(f"Accuracy al conjunt de prova: {test_accuracy:.4f}")
31/31 ━━━━━━━━━━━━━━━━━━━━ 2s 54ms/step - accuracy: 0.7295 - loss: 0.8769 Pèrdua al conjunt de prova: 0.8662 Accuracy al conjunt de prova: 0.7341
Fins ara hem treballat en un problema de “classificació”. Als apartats restants abordarem un problema diferent però relacionat amb la visió per computador: la superresolució d'imatges. La superresolució consisteix a generar una imatge d'alta resolució (HR) a partir d'una de baixa resolució (LR), intentant recuperar o inferir els detalls perduts en reduir la imatge. És un problema d'aprenentatge supervisat on el model aprèn una transformació imatge->imatge.
Farem servir novament la base de dades Imagenette per crear exemples d'entrenament: a partir de cada imatge original (320px) generarem una versió reduïda (p. ex. 80px) que servirà com a entrada, tenint com a sortida esperada la imatge original. Així, el model aprendrà a mapejar de baixa a alta resolució. En lloc d'una xarxa convolucional per a la classificació, necessitarem una xarxa capaç de processar una imatge d'entrada i produir una imatge de sortida. Les capes transposades de convolució (Conv2DTranspose) o les tècniques de upsampling són les peces clau per a aquests models, ja que permeten augmentar les dimensions espacials de les dades.
A continuació, crearem el conjunt de dades per a superresolució i entrenarem una CNN de superresolució simple.
Primer generarem els parells d'imatges d'entrenament i de validació per a superresolució. Partirem de les imatges originals d'entrenament (i validació) a la seva resolució completa i, per simplicitat, les redimensionarem a una mida fixa de 320×320 (ignorant la relació d'aspecte original, similar al fet en classificació) per utilitzar-les com a imatges de referència d'alta resolució (HR), i les reduirem a 1/4 de la seva mida (aprox). Per realitzar ambdues transformacions (ajust de les imatges originals a 320x320 i la seva reducció a 80x80) utilitzarem mètodes d'interpolació bicúbica per simular imatges degrades suaument.
Procedim a obtenir les rutes d'imatges d'entrenament i validació, després creem un Dataset aplicant la funció de mapeig que realitza la lectura i la transformació:
/train/ (per crear els subconjunts d'entrenament i validació) a la seva mida completa (320px) com a HR, i genera imatges LR reduint-les a 1/4 de la seva mida lineal (80×80). Es demana:
import random
# Funció per preprocessar una imatge per resolució, genera un parell d'imatges:
# una de baixa resolució (LR) i la corresponent en alta resolució (HR)
def preprocess_image_for_sr(filepath):
# Llegir el fitxer d'imatge en format JPEG des de la ruta indicada
img = tf.io.read_file(filepath)
img = tf.image.decode_jpeg(img, channels=3)
# Convertir imatge a float32 i normalitzar els valors dels píxels a l'interval [0,1]
img = tf.image.convert_image_dtype(img, tf.float32)
# Redimensió de la imatge a 320x320 per obtenir la versió HR
# Redimensió de la imatge a 80x80 per obtenir la versió LR
# Es fa servir el mètode BICUBIC per mantenir els detalls
hr = tf.image.resize(img, [320, 320], method=tf.image.ResizeMethod.BICUBIC)
lr = tf.image.resize(img, [80, 80], method=tf.image.ResizeMethod.BICUBIC)
# Assegura que els valors dels píxels estan dintre de [0,1]
hr = tf.clip_by_value(hr, 0.0, 1.0)
lr = tf.clip_by_value(lr, 0.0, 1.0)
return lr, hr
# Creació de llistes buides per emmagatzemar les rutes dels fitxers d'entrenament i validació
train_files = []
val_files = []
for cls in train_classes:
cls_files = sorted((train_dir/cls).glob("*.*"))
# Separació de fitxers 20% per validació i 80% per entrenament
split_idx = int(len(cls_files) * 0.2)
val_files += [str(p) for p in cls_files[:split_idx]]
train_files += [str(p) for p in cls_files[split_idx:]]
random.shuffle(train_files)
random.shuffle(val_files)
# Creació dels Datasets per entrenament i validació
# Cada element està format per un parell (LR, HR)
train_sr_ds = tf.data.Dataset.from_tensor_slices(train_files).map(preprocess_image_for_sr, num_parallel_calls=tf.data.AUTOTUNE).batch(32).prefetch(tf.data.AUTOTUNE)
val_sr_ds = tf.data.Dataset.from_tensor_slices(val_files).map(preprocess_image_for_sr, num_parallel_calls=tf.data.AUTOTUNE).batch(32).prefetch(tf.data.AUTOTUNE)
Procedim a verificar la base de dades:
# Verificació d'un parell LR-HR
for lr_batch, hr_batch in train_sr_ds.take(1):
lr_img = lr_batch[0].numpy()
hr_img = hr_batch[0].numpy()
print("LR shape:", lr_img.shape, "HR shape:", hr_img.shape)
print("LR pixel range:", lr_img.min(), "-", lr_img.max())
print("HR pixel range:", hr_img.min(), "-", hr_img.max())
break
LR shape: (80, 80, 3) HR shape: (320, 320, 3) LR pixel range: 0.0 - 1.0 HR pixel range: 0.0 - 1.0
Ara definim un model CNN per a Superresolució. Optarem per una arquitectura senzilla amb UpSampling2D per escalar la imatge gradualment. Implementem el model:
Upsampling2D (factor=2) o Conv2DTranspose amb stride=2 per doblar l'amplada i alçada (80 -> 160).
from tensorflow.keras.models import Sequential
# Modelo de Superrsolució (80px -> 320px)
sr_model = Sequential([
Input(shape=(80, 80, 3)),
Conv2D(64, kernel_size=3, padding='same', activation='relu'),
Conv2D(64, kernel_size=3, padding='same', activation='relu'),
UpSampling2D(size=(2, 2)), # 80x80 → 160x160
Conv2D(64, kernel_size=3, padding='same', activation='relu'),
UpSampling2D(size=(2, 2)), # 160x160 → 320x320
Conv2D(3, kernel_size=3, padding='same', activation='sigmoid')
])
sr_model.summary()
Model: "sequential"
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━┓ ┃ Layer (type) ┃ Output Shape ┃ Param # ┃ ┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━┩ │ conv2d_9 (Conv2D) │ (None, 80, 80, 64) │ 1,792 │ ├──────────────────────────────────────┼─────────────────────────────┼─────────────────┤ │ conv2d_10 (Conv2D) │ (None, 80, 80, 64) │ 36,928 │ ├──────────────────────────────────────┼─────────────────────────────┼─────────────────┤ │ up_sampling2d (UpSampling2D) │ (None, 160, 160, 64) │ 0 │ ├──────────────────────────────────────┼─────────────────────────────┼─────────────────┤ │ conv2d_11 (Conv2D) │ (None, 160, 160, 64) │ 36,928 │ ├──────────────────────────────────────┼─────────────────────────────┼─────────────────┤ │ up_sampling2d_1 (UpSampling2D) │ (None, 320, 320, 64) │ 0 │ ├──────────────────────────────────────┼─────────────────────────────┼─────────────────┤ │ conv2d_12 (Conv2D) │ (None, 320, 320, 3) │ 1,731 │ └──────────────────────────────────────┴─────────────────────────────┴─────────────────┘
Total params: 77,379 (302.26 KB)
Trainable params: 77,379 (302.26 KB)
Non-trainable params: 0 (0.00 B)
# Compilació de la xarxa
lr = 1e-3
sr_model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=lr),
loss='mse')
callbacks = [
EarlyStopping(monitor='val_loss', patience=10, restore_best_weights=True)
]
# Entrenament de la xarxa
hist = sr_model.fit(
train_sr_ds,
validation_data=val_sr_ds,
epochs=100,
callbacks=callbacks
)
Epoch 1/100 237/237 ━━━━━━━━━━━━━━━━━━━━ 73s 244ms/step - loss: 0.0265 - val_loss: 0.0065 Epoch 2/100 237/237 ━━━━━━━━━━━━━━━━━━━━ 40s 167ms/step - loss: 0.0063 - val_loss: 0.0061 Epoch 3/100 237/237 ━━━━━━━━━━━━━━━━━━━━ 40s 170ms/step - loss: 0.0060 - val_loss: 0.0061 Epoch 4/100 237/237 ━━━━━━━━━━━━━━━━━━━━ 40s 169ms/step - loss: 0.0058 - val_loss: 0.0061 Epoch 5/100 237/237 ━━━━━━━━━━━━━━━━━━━━ 40s 169ms/step - loss: 0.0057 - val_loss: 0.0057 Epoch 6/100 237/237 ━━━━━━━━━━━━━━━━━━━━ 40s 169ms/step - loss: 0.0057 - val_loss: 0.0058 Epoch 7/100 237/237 ━━━━━━━━━━━━━━━━━━━━ 40s 170ms/step - loss: 0.0057 - val_loss: 0.0057 Epoch 8/100 237/237 ━━━━━━━━━━━━━━━━━━━━ 40s 170ms/step - loss: 0.0056 - val_loss: 0.0057 Epoch 9/100 237/237 ━━━━━━━━━━━━━━━━━━━━ 40s 169ms/step - loss: 0.0056 - val_loss: 0.0056 Epoch 10/100 237/237 ━━━━━━━━━━━━━━━━━━━━ 40s 169ms/step - loss: 0.0056 - val_loss: 0.0056 Epoch 11/100 237/237 ━━━━━━━━━━━━━━━━━━━━ 40s 170ms/step - loss: 0.0055 - val_loss: 0.0056 Epoch 12/100 237/237 ━━━━━━━━━━━━━━━━━━━━ 40s 169ms/step - loss: 0.0056 - val_loss: 0.0056 Epoch 13/100 237/237 ━━━━━━━━━━━━━━━━━━━━ 40s 169ms/step - loss: 0.0055 - val_loss: 0.0056 Epoch 14/100 237/237 ━━━━━━━━━━━━━━━━━━━━ 40s 169ms/step - loss: 0.0055 - val_loss: 0.0055 Epoch 15/100 237/237 ━━━━━━━━━━━━━━━━━━━━ 40s 171ms/step - loss: 0.0055 - val_loss: 0.0056 Epoch 16/100 237/237 ━━━━━━━━━━━━━━━━━━━━ 40s 169ms/step - loss: 0.0055 - val_loss: 0.0056 Epoch 17/100 237/237 ━━━━━━━━━━━━━━━━━━━━ 40s 170ms/step - loss: 0.0055 - val_loss: 0.0055 Epoch 18/100 237/237 ━━━━━━━━━━━━━━━━━━━━ 40s 170ms/step - loss: 0.0055 - val_loss: 0.0055 Epoch 19/100 237/237 ━━━━━━━━━━━━━━━━━━━━ 41s 171ms/step - loss: 0.0054 - val_loss: 0.0055 Epoch 20/100 237/237 ━━━━━━━━━━━━━━━━━━━━ 40s 169ms/step - loss: 0.0056 - val_loss: 0.0055 Epoch 21/100 237/237 ━━━━━━━━━━━━━━━━━━━━ 40s 169ms/step - loss: 0.0054 - val_loss: 0.0055 Epoch 22/100 237/237 ━━━━━━━━━━━━━━━━━━━━ 40s 169ms/step - loss: 0.0054 - val_loss: 0.0055 Epoch 23/100 237/237 ━━━━━━━━━━━━━━━━━━━━ 40s 169ms/step - loss: 0.0054 - val_loss: 0.0055 Epoch 24/100 237/237 ━━━━━━━━━━━━━━━━━━━━ 40s 169ms/step - loss: 0.0054 - val_loss: 0.0055 Epoch 25/100 237/237 ━━━━━━━━━━━━━━━━━━━━ 40s 169ms/step - loss: 0.0054 - val_loss: 0.0055 Epoch 26/100 237/237 ━━━━━━━━━━━━━━━━━━━━ 40s 169ms/step - loss: 0.0054 - val_loss: 0.0055 Epoch 27/100 237/237 ━━━━━━━━━━━━━━━━━━━━ 40s 169ms/step - loss: 0.0054 - val_loss: 0.0054 Epoch 28/100 237/237 ━━━━━━━━━━━━━━━━━━━━ 40s 169ms/step - loss: 0.0054 - val_loss: 0.0054 Epoch 29/100 237/237 ━━━━━━━━━━━━━━━━━━━━ 40s 169ms/step - loss: 0.0054 - val_loss: 0.0054 Epoch 30/100 237/237 ━━━━━━━━━━━━━━━━━━━━ 40s 169ms/step - loss: 0.0054 - val_loss: 0.0054 Epoch 31/100 237/237 ━━━━━━━━━━━━━━━━━━━━ 40s 169ms/step - loss: 0.0054 - val_loss: 0.0054 Epoch 32/100 237/237 ━━━━━━━━━━━━━━━━━━━━ 40s 168ms/step - loss: 0.0054 - val_loss: 0.0054 Epoch 33/100 237/237 ━━━━━━━━━━━━━━━━━━━━ 40s 168ms/step - loss: 0.0053 - val_loss: 0.0054 Epoch 34/100 237/237 ━━━━━━━━━━━━━━━━━━━━ 40s 169ms/step - loss: 0.0053 - val_loss: 0.0054 Epoch 35/100 237/237 ━━━━━━━━━━━━━━━━━━━━ 40s 169ms/step - loss: 0.0053 - val_loss: 0.0054 Epoch 36/100 237/237 ━━━━━━━━━━━━━━━━━━━━ 40s 169ms/step - loss: 0.0053 - val_loss: 0.0054 Epoch 37/100 237/237 ━━━━━━━━━━━━━━━━━━━━ 40s 169ms/step - loss: 0.0053 - val_loss: 0.0054 Epoch 38/100 237/237 ━━━━━━━━━━━━━━━━━━━━ 40s 169ms/step - loss: 0.0053 - val_loss: 0.0054 Epoch 39/100 237/237 ━━━━━━━━━━━━━━━━━━━━ 40s 169ms/step - loss: 0.0053 - val_loss: 0.0054 Epoch 40/100 237/237 ━━━━━━━━━━━━━━━━━━━━ 40s 169ms/step - loss: 0.0053 - val_loss: 0.0054 Epoch 41/100 237/237 ━━━━━━━━━━━━━━━━━━━━ 40s 169ms/step - loss: 0.0053 - val_loss: 0.0054 Epoch 42/100 237/237 ━━━━━━━━━━━━━━━━━━━━ 40s 170ms/step - loss: 0.0053 - val_loss: 0.0054 Epoch 43/100 237/237 ━━━━━━━━━━━━━━━━━━━━ 40s 168ms/step - loss: 0.0053 - val_loss: 0.0054 Epoch 44/100 237/237 ━━━━━━━━━━━━━━━━━━━━ 40s 168ms/step - loss: 0.0053 - val_loss: 0.0054 Epoch 45/100 237/237 ━━━━━━━━━━━━━━━━━━━━ 40s 168ms/step - loss: 0.0053 - val_loss: 0.0053 Epoch 46/100 237/237 ━━━━━━━━━━━━━━━━━━━━ 40s 170ms/step - loss: 0.0053 - val_loss: 0.0053 Epoch 47/100 237/237 ━━━━━━━━━━━━━━━━━━━━ 40s 169ms/step - loss: 0.0053 - val_loss: 0.0053 Epoch 48/100 237/237 ━━━━━━━━━━━━━━━━━━━━ 40s 168ms/step - loss: 0.0053 - val_loss: 0.0053 Epoch 49/100 237/237 ━━━━━━━━━━━━━━━━━━━━ 40s 168ms/step - loss: 0.0053 - val_loss: 0.0053 Epoch 50/100 237/237 ━━━━━━━━━━━━━━━━━━━━ 40s 170ms/step - loss: 0.0053 - val_loss: 0.0053 Epoch 51/100 237/237 ━━━━━━━━━━━━━━━━━━━━ 40s 168ms/step - loss: 0.0053 - val_loss: 0.0053 Epoch 52/100 237/237 ━━━━━━━━━━━━━━━━━━━━ 40s 169ms/step - loss: 0.0052 - val_loss: 0.0053 Epoch 53/100 237/237 ━━━━━━━━━━━━━━━━━━━━ 40s 169ms/step - loss: 0.0052 - val_loss: 0.0053 Epoch 54/100 237/237 ━━━━━━━━━━━━━━━━━━━━ 40s 170ms/step - loss: 0.0052 - val_loss: 0.0053 Epoch 55/100 237/237 ━━━━━━━━━━━━━━━━━━━━ 40s 168ms/step - loss: 0.0052 - val_loss: 0.0053 Epoch 56/100 237/237 ━━━━━━━━━━━━━━━━━━━━ 40s 168ms/step - loss: 0.0052 - val_loss: 0.0053 Epoch 57/100 237/237 ━━━━━━━━━━━━━━━━━━━━ 40s 169ms/step - loss: 0.0052 - val_loss: 0.0053 Epoch 58/100 237/237 ━━━━━━━━━━━━━━━━━━━━ 40s 170ms/step - loss: 0.0052 - val_loss: 0.0053 Epoch 59/100 237/237 ━━━━━━━━━━━━━━━━━━━━ 40s 168ms/step - loss: 0.0052 - val_loss: 0.0053 Epoch 60/100 237/237 ━━━━━━━━━━━━━━━━━━━━ 40s 168ms/step - loss: 0.0052 - val_loss: 0.0053 Epoch 61/100 237/237 ━━━━━━━━━━━━━━━━━━━━ 40s 168ms/step - loss: 0.0052 - val_loss: 0.0053 Epoch 62/100 237/237 ━━━━━━━━━━━━━━━━━━━━ 40s 169ms/step - loss: 0.0052 - val_loss: 0.0053 Epoch 63/100 237/237 ━━━━━━━━━━━━━━━━━━━━ 40s 167ms/step - loss: 0.0052 - val_loss: 0.0053 Epoch 64/100 237/237 ━━━━━━━━━━━━━━━━━━━━ 40s 168ms/step - loss: 0.0052 - val_loss: 0.0053 Epoch 65/100 237/237 ━━━━━━━━━━━━━━━━━━━━ 40s 168ms/step - loss: 0.0052 - val_loss: 0.0053 Epoch 66/100 237/237 ━━━━━━━━━━━━━━━━━━━━ 40s 169ms/step - loss: 0.0052 - val_loss: 0.0053 Epoch 67/100 237/237 ━━━━━━━━━━━━━━━━━━━━ 40s 167ms/step - loss: 0.0052 - val_loss: 0.0053 Epoch 68/100 237/237 ━━━━━━━━━━━━━━━━━━━━ 40s 168ms/step - loss: 0.0052 - val_loss: 0.0053 Epoch 69/100 237/237 ━━━━━━━━━━━━━━━━━━━━ 40s 168ms/step - loss: 0.0052 - val_loss: 0.0053 Epoch 70/100 237/237 ━━━━━━━━━━━━━━━━━━━━ 40s 169ms/step - loss: 0.0052 - val_loss: 0.0053 Epoch 71/100 237/237 ━━━━━━━━━━━━━━━━━━━━ 40s 168ms/step - loss: 0.0052 - val_loss: 0.0053 Epoch 72/100 237/237 ━━━━━━━━━━━━━━━━━━━━ 40s 168ms/step - loss: 0.0052 - val_loss: 0.0053 Epoch 73/100 237/237 ━━━━━━━━━━━━━━━━━━━━ 40s 167ms/step - loss: 0.0052 - val_loss: 0.0053 Epoch 74/100 237/237 ━━━━━━━━━━━━━━━━━━━━ 40s 169ms/step - loss: 0.0052 - val_loss: 0.0053 Epoch 75/100 237/237 ━━━━━━━━━━━━━━━━━━━━ 40s 168ms/step - loss: 0.0052 - val_loss: 0.0053 Epoch 76/100 237/237 ━━━━━━━━━━━━━━━━━━━━ 40s 168ms/step - loss: 0.0052 - val_loss: 0.0053 Epoch 77/100 237/237 ━━━━━━━━━━━━━━━━━━━━ 40s 168ms/step - loss: 0.0052 - val_loss: 0.0053 Epoch 78/100 237/237 ━━━━━━━━━━━━━━━━━━━━ 40s 168ms/step - loss: 0.0052 - val_loss: 0.0053 Epoch 79/100 237/237 ━━━━━━━━━━━━━━━━━━━━ 40s 167ms/step - loss: 0.0052 - val_loss: 0.0053 Epoch 80/100 237/237 ━━━━━━━━━━━━━━━━━━━━ 40s 167ms/step - loss: 0.0052 - val_loss: 0.0053 Epoch 81/100 237/237 ━━━━━━━━━━━━━━━━━━━━ 40s 167ms/step - loss: 0.0052 - val_loss: 0.0053 Epoch 82/100 237/237 ━━━━━━━━━━━━━━━━━━━━ 40s 168ms/step - loss: 0.0052 - val_loss: 0.0053 Epoch 83/100 237/237 ━━━━━━━━━━━━━━━━━━━━ 40s 167ms/step - loss: 0.0051 - val_loss: 0.0053 Epoch 84/100 237/237 ━━━━━━━━━━━━━━━━━━━━ 40s 167ms/step - loss: 0.0051 - val_loss: 0.0053 Epoch 85/100 237/237 ━━━━━━━━━━━━━━━━━━━━ 40s 167ms/step - loss: 0.0051 - val_loss: 0.0053 Epoch 86/100 237/237 ━━━━━━━━━━━━━━━━━━━━ 40s 169ms/step - loss: 0.0051 - val_loss: 0.0053 Epoch 87/100 237/237 ━━━━━━━━━━━━━━━━━━━━ 40s 168ms/step - loss: 0.0051 - val_loss: 0.0053 Epoch 88/100 237/237 ━━━━━━━━━━━━━━━━━━━━ 40s 168ms/step - loss: 0.0051 - val_loss: 0.0053 Epoch 89/100 237/237 ━━━━━━━━━━━━━━━━━━━━ 40s 168ms/step - loss: 0.0051 - val_loss: 0.0053 Epoch 90/100 237/237 ━━━━━━━━━━━━━━━━━━━━ 40s 169ms/step - loss: 0.0051 - val_loss: 0.0053 Epoch 91/100 237/237 ━━━━━━━━━━━━━━━━━━━━ 40s 167ms/step - loss: 0.0051 - val_loss: 0.0053 Epoch 92/100 237/237 ━━━━━━━━━━━━━━━━━━━━ 40s 167ms/step - loss: 0.0051 - val_loss: 0.0053 Epoch 93/100 237/237 ━━━━━━━━━━━━━━━━━━━━ 40s 167ms/step - loss: 0.0051 - val_loss: 0.0052 Epoch 94/100 237/237 ━━━━━━━━━━━━━━━━━━━━ 40s 169ms/step - loss: 0.0051 - val_loss: 0.0052 Epoch 95/100 237/237 ━━━━━━━━━━━━━━━━━━━━ 40s 168ms/step - loss: 0.0051 - val_loss: 0.0053 Epoch 96/100 237/237 ━━━━━━━━━━━━━━━━━━━━ 40s 168ms/step - loss: 0.0051 - val_loss: 0.0052 Epoch 97/100 237/237 ━━━━━━━━━━━━━━━━━━━━ 40s 168ms/step - loss: 0.0051 - val_loss: 0.0052 Epoch 98/100 237/237 ━━━━━━━━━━━━━━━━━━━━ 40s 169ms/step - loss: 0.0051 - val_loss: 0.0052 Epoch 99/100 237/237 ━━━━━━━━━━━━━━━━━━━━ 40s 167ms/step - loss: 0.0051 - val_loss: 0.0052 Epoch 100/100 237/237 ━━━━━━━━━━━━━━━━━━━━ 40s 167ms/step - loss: 0.0051 - val_loss: 0.0053
def grafica_loss(hist, lr):
"""
Funció auxiliar per dibuixar les corbes de loss de l'entrenament i validació.
"""
epochs = range(1, len(hist.history['loss']) + 1)
plt.figure(figsize=(6, 5))
plt.plot(epochs, hist.history['loss'], label='Pèrdua Entrenament')
plt.plot(epochs, hist.history['val_loss'], label='Pèrdua Validació')
plt.title(f"Pèrdua (Learning rate = {lr})")
plt.xlabel("Època")
plt.ylabel("Pèrdua")
plt.legend()
plt.tight_layout()
plt.show()
# Resultats
grafica_loss(hist, lr)
Amb el model de superresolució entrenat, avaluarem el seu rendiment tant quantitativament (amb mètriques) com qualitativament (visualització d’imatges). Per fer -ho, realitzarem la inferència: és a dir, prendre imatges LR del conjunt de test, passar-les pel model sr_model per obtenir imatges supers resoltes (SR) i comparar-les amb les imatges HR originals. Calcularem la mètrica PSNR (Peak Signal-to-Notise Ratio) per quantificar la qualitat.
El PSNR es mesura en decibels (dB) i els valors superiors indiquen una major similitud amb la imatge original (per exemple,> 30 dB sol indicar una qualitat de reconstrucció molt bona).
També visualitzarem alguns exemples, on mostraren la imatge de baixa resolució (LR), la super resolta per la CNN (SR) i l’original d’alta resolució (HR), per inspeccionar els detalls amb l’ull nu.
tf.image.resize() juntament amb method=tf.image.ResizeMethod.NEAREST_NEIGHBOR), la imatge generada pel model (SR) i la imatge HR original, per poder comparar la qualitat visualment a part del valor de la PSNR.# Preparar conjunt de test per a superresolució (a partir del directori /val)
test_files = list(val_dir.glob("**/*.*"))
test_sr_ds = tf.data.Dataset.from_tensor_slices(
tf.constant([str(file) for file in test_files], dtype=tf.string)
).map(preprocess_image_for_sr, num_parallel_calls=tf.data.AUTOTUNE) \
.batch(1) \
.prefetch(tf.data.AUTOTUNE)
# Llista per emmagatzemar la PSNR de cada imatge
psnr_list = []
# Recórrer el dataset de test (batch size = 1)
for lr_batch, hr_batch in test_sr_ds:
# Generar la imatge de superresolució a partir de la imatge LR
sr_pred = sr_model.predict(lr_batch, verbose=0)
# Treure la dimensió de batch i assegurar que els valors estiguin a [0,1]
sr_img = tf.squeeze(sr_pred, axis=0)
sr_img = tf.clip_by_value(sr_img, 0.0, 1.0)
# Convertir la imatge HR a array de numpy (llevem la dimensió de batch)
hr_img = tf.squeeze(hr_batch, axis=0).numpy()
# Calcular la PSNR entre la imatge SR generada i la imatge HR original
psnr_value = tf.image.psnr(sr_img, hr_img, max_val=1.0)
# Emmagatzemar el valor de PSNR a la llista
psnr_list.append(psnr_value.numpy())
# Calcular la PSNR mitjana del conjunt de test
avg_psnr = np.mean(psnr_list)
print("PSNR mitjana al conjunt de test: {:.2f} dB".format(avg_psnr))
PSNR mitjana al conjunt de test: 24.10 dB
import math
# Escollir algunes imatges de test aleatòries per a visualització
sample_files = random.sample(test_files, 3)
# Per a cada imatge
for file in sample_files:
# Llegir i processar una imatge de test
lr, hr = preprocess_image_for_sr(str(file))
# Generar superresolució
lr_expanded = tf.expand_dims(lr, axis=0)
sr = sr_model.predict(lr_expanded, verbose=0)
sr = tf.squeeze(sr, axis=0)
# Calcular PSNR
psnr_img = tf.image.psnr(sr, hr, max_val=1.0).numpy()
# Preparar imatges per visualitzar de la mateixa mida
lr_up = tf.image.resize(lr, [320, 320], method=tf.image.ResizeMethod.NEAREST_NEIGHBOR)
# Les imatges estan a [0,1], escalar a [0,255] per mostrar correctament
lr_disp = (lr_up.numpy() * 255).astype(np.uint8)
sr_disp = (sr.numpy() * 255).astype(np.uint8)
hr_disp = (hr.numpy() * 255).astype(np.uint8)
# Mostrar les imatges
plt.figure(figsize=(12, 4))
plt.subplot(1, 3, 1)
plt.imshow(lr_disp)
plt.title("LR (ampliada)")
plt.axis("off")
plt.subplot(1, 3, 2)
plt.imshow(sr_disp)
plt.title("SR\nPSNR: {:.2f} dB".format(psnr_img))
plt.axis("off")
plt.subplot(1, 3, 3)
plt.imshow(hr_disp)
plt.title("HR original")
plt.axis("off")
plt.tight_layout()
plt.show()
Els avenços recents en superresolució han produït arquitectures més complexes (p. ex., basades en xarxes generatives adversàries) que aconsegueixen resultats notablement millors, a costa d'un entrenament costós. Un d'aquests models és ESRGAN (Enhanced Super-Resolution GAN per Xintao Wang et al.), entrenat en grans bases de dades d'imatges HD (com DIV2K) per aconseguir superresolució 4x de gran fidelitat.
Farem servir un model pre-entrenat d'ESRGAN disponible via TensorFlow Hub
Aquest model ha estat entrenat al conjunt DIV2K (imatges d'alta qualitat) amb degradació bicúbica, per la qual cosa està especialitzat en produir imatges 4× més grans amb detall notable.
Farem servir ESRGAN per aplicar superresolució a les mateixes imatges de test i compararem els resultats amb el nostre model implementat a l'exercici 5.
"https://tfhub.dev/captain-pool/esrgan-tf2/1").import tensorflow_hub as hub
# Carreguem el model ESRGAN des de TF-Hub:
esrgan = hub.load("https://tfhub.dev/captain-pool/esrgan-tf2/1")
print("Model ESRGAN carregat.")
Model ESRGAN carregat.
# Llista per emmagatzemar la PSNR de cada imatge usant ESRGAN
psnr_esrgan_list = []
# Recórrer el dataset de test (batch size = 1)
for lr_batch, hr_batch in test_sr_ds:
# Preparar la imatge LR per ESRGAN:
# ESRGAN espera entrada en rang [0,255] com a float32, i test_sr_ds té imatges a [0,1]
lr_255 = tf.clip_by_value(lr_batch * 255.0, 0.0, 255.0)
# Generar la imatge de superresolució amb ESRGAN
sr_esrgan_255 = esrgan(lr_255)
# Treure la dimensió de batch i normalitzar la sortida a [0,1]
sr_esrgan = tf.clip_by_value(sr_esrgan_255 / 255.0, 0.0, 1.0)
# Obtenir la imatge HR corresponent (llevem la dimensió de batch)
hr_img = tf.squeeze(hr_batch, axis=0)
# Calcular la PSNR entre la imatge generada per ESRGAN i la imatge HR original
psnr_val = tf.image.psnr(sr_esrgan, hr_img, max_val=1.0)
psnr_esrgan_list.append(psnr_val.numpy())
# Calcular la PSNR mitjana en el conjunt de test amb ESRGAN
avg_psnr_esrgan = np.mean(psnr_esrgan_list)
print("PSNR mitjana en el conjunt de test amb ESRGAN: {:.2f} dB".format(avg_psnr_esrgan))
PSNR mitjana en el conjunt de test amb ESRGAN: 19.13 dB
# Apliquem ESRGAN a les mateixes imatges d'exemple que fem servir amb el nostre model
sample_files = random.sample(test_files, 3)
# Per a cada imatge
for file in sample_files:
# Preparar la imatge LR de test
lr, hr = preprocess_image_for_sr(str(file))
lr_esrgan = tf.expand_dims(lr * 255.0, axis=0)
sr_esrgan_255 = esrgan(lr_esrgan)
# Convertir sortida a [0,1] float
sr_esrgan = tf.squeeze(sr_esrgan_255, axis=0) / 255.0
sr_esrgan = tf.clip_by_value(sr_esrgan, 0.0, 1.0)
# Calcular PSNR comparat amb HR
psnr_esrgan = tf.image.psnr(sr_esrgan, hr, max_val=1.0).numpy()
# Visualitzar comparativa
lr_up = tf.image.resize(lr, [320, 320], method=tf.image.ResizeMethod.NEAREST_NEIGHBOR)
lr_disp = (lr_up.numpy() * 255).astype(np.uint8)
sr_disp = (sr_esrgan.numpy() * 255).astype(np.uint8)
hr_disp = (hr.numpy() * 255).astype(np.uint8)
plt.figure(figsize=(16, 4))
plt.subplot(1, 3, 1)
plt.imshow(lr_disp)
plt.title("LR (ampliada)")
plt.axis("off")
plt.subplot(1, 3, 2)
plt.imshow(sr_disp)
plt.title("ESRGAN SR\nPSNR: {:.2f} dB".format(psnr_esrgan))
plt.axis("off")
plt.subplot(1, 3, 3)
plt.imshow(hr_disp)
plt.title("HR original")
plt.axis("off")
plt.tight_layout()
plt.show()
En aquesta pràctica hem explorat tant la classificació d’imatges amb xarxes neuronals (denses vs convolucionals) com la superresolució amb CNNs, aplicant-ho tot sobre la base de dades d’Imagenette.